mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 11:42:29 +08:00)

add new provider Solar (#6884)
parent 541bf1db5a
commit 2e941bb91c

@@ -6,6 +6,7 @@
 - nvidia
 - nvidia_nim
 - cohere
+- upstage
 - bedrock
 - togetherai
 - openrouter

@@ -0,0 +1,14 @@
<svg width="500" height="162" viewBox="0 0 500 162" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M42.5305 42L40.0623 50.518H16.3342L18.7422 42H42.5305ZM38.0367 59.8974L29.007 92.1937C27.2612 98.4243 29.007 102.277 33.9432 102.277C38.8795 102.277 43.575 98.5146 52.8455 88.5518L61.1227 59.8974H71.9584L61.1227 98.0631C60.28 100.862 61.1227 102.277 63.5909 102.277C66.1596 102.277 71.0349 100.118 82.6087 93.5465L89.8095 68.4768H77.9805L80.4486 59.8986H103.264L101.006 67.634L100.404 69.6808C102.382 67.3205 104.843 65.4115 107.621 64.0824C110.398 62.7532 113.429 62.0348 116.507 61.9754C128.848 61.9754 136.945 71.878 136.945 85.9042C136.945 102.338 125.778 114.769 111.029 114.769C106.921 114.842 102.895 113.609 99.5315 111.248L98.9596 110.826L98.7188 110.616L97.7255 109.743C94.829 106.928 92.7933 103.347 91.8562 99.4187L81.9837 133.762H71.0577L79.4382 104.585C69.4379 110.003 64.3802 111.668 60.1295 111.668C51.4308 111.668 47.9092 106.28 50.4977 96.8892L50.7385 95.9561C41.3476 107.334 35.2375 111.668 28.4351 111.668C19.4054 111.668 14.6798 103.812 17.7499 93.2472L24.7931 68.4756H25.1522L25.1604 68.4456H0L2.40793 59.8974L27.4999 59.8974H38.0367ZM97.0332 91.1414C97.0332 100.322 102.511 105.679 110.096 105.679C119.577 105.679 126.621 97.0409 126.621 85.7838C126.621 76.8143 121.564 71.0352 113.889 71.0352C104.287 71.0352 97.0332 80.5165 97.0332 91.1414ZM201.953 72.3305H216.822V63.692H201.953V46.3249H201.743L191.419 54.3312V63.692H182.871V72.3305H191.419V99.0585C191.419 102.43 191.419 114.861 205.114 114.861C208.988 115.003 212.812 113.952 216.07 111.851V103.965H215.859C213.598 105.225 211.072 105.936 208.485 106.041C203.218 106.041 201.953 102.279 201.953 94.9951V72.3305ZM142.031 100.5V109.53C147.611 113.051 154.065 114.938 160.663 114.978C171.107 114.978 179.324 109.59 179.324 99.597C179.324 89.4421 170.197 86.5376 162.598 84.1193C157.179 82.3946 152.536 80.9172 152.536 77.2334C152.536 73.6516 156.028 71.2136 161.807 71.2136C167.043 71.2212 172.142 72.8859 176.375 75.9692H176.585V66.9395C172.007 64.0811 166.722 62.5591 161.325 62.545C150.188 62.545 142.814 68.5648 142.814 77.9257C142.814 87.5253 151.332 90.2816 158.735 92.6769C164.423 94.5174 169.452 96.1448 169.452 100.5C169.452 104.292 165.569 106.7 159.549 106.7C153.327 106.458 147.332 104.292 142.393 100.5H142.031ZM266.552 79.4936V113.746H258.425L257.492 106.071C255.696 108.954 253.167 111.308 250.163 112.895C247.16 114.481 243.79 115.242 240.396 115.101C230.915 115.101 222.066 109.623 222.066 99.5095C222.066 87.801 232.6 84.1289 244.188 84.1289C251.894 84.1289 256.228 83.7075 256.228 78.7412C256.228 73.7748 251.773 71.6077 244.73 71.6077C237.667 71.573 230.852 74.2068 225.647 78.982H225.437L225.768 69.5007C231.407 65.0923 238.299 62.5844 245.453 62.3371C255.897 62.2168 266.552 66.551 266.552 79.4936ZM256.77 93.6402V88.3729C254.422 91.3828 249.697 92.045 243.466 92.045C237.236 92.045 232.42 94.3626 232.42 99.4193C232.42 104.476 237.567 106.794 242.623 106.794C244.427 106.908 246.236 106.654 247.938 106.046C249.641 105.439 251.202 104.49 252.526 103.26C253.849 102.029 254.909 100.541 255.638 98.8869C256.368 97.2331 256.753 95.4478 256.77 93.6402ZM324.577 63.6931H316.481L315.307 72.2412C313.561 69.1151 310.983 66.5344 307.859 64.7861C304.734 63.0379 301.186 62.1906 297.609 62.3386C284.756 62.3386 273.891 72.3315 273.891 87.6218C273.891 103.424 284.425 112.905 296.856 112.905C300.191 113.01 303.501 112.297 306.497 110.828C309.493 109.359 312.084 107.178 314.043 104.477V108.481C314.043 118.925 307.722 124.614 298.451 124.614C290.792 124.634 283.357 122.032 277.382 117.239H277.171V126.811C283.552 
131.553 291.347 134.003 299.294 133.764C314.253 133.764 324.577 124.704 324.577 106.494V63.6931ZM299.806 71.3984C309.287 71.3984 314.584 79.2844 314.584 87.5014C314.584 96.0496 309.287 103.845 299.806 103.845C289.602 103.845 284.756 95.9292 284.756 87.5014C284.756 79.0737 289.271 71.3984 299.806 71.3984ZM348.753 91.8308C345.705 91.8308 343.327 92.2803 343.327 95.3068C343.327 100.461 349.606 105.795 359.329 105.795C362.879 105.975 366.43 105.446 369.766 104.241C373.103 103.036 376.157 101.179 378.745 98.7828H378.958V107.773C373.384 112.554 366.207 115.138 358.811 115.024C343.023 115.024 332.02 104.416 332.02 88.5645C331.884 85.1244 332.457 81.6929 333.705 78.4762C334.953 75.2596 336.85 72.3244 339.282 69.8471C341.713 67.3697 344.629 65.4014 347.854 64.0606C351.08 62.7198 354.547 62.034 358.049 62.0447C363.658 61.8285 369.142 63.7075 373.4 67.3041C377.658 70.9007 380.373 75.9484 381 81.4326V91.8308H348.753ZM350.246 84.1895H370.424C370.49 82.458 370.188 80.7321 369.538 79.1217C368.887 77.5114 367.902 76.052 366.646 74.8367C365.39 73.6214 363.89 72.677 362.242 72.0636C360.594 71.4502 358.833 71.1813 357.074 71.2742C347.93 71.2742 343.114 77.567 342.474 86.4968C344.704 84.8454 347.459 84.0275 350.246 84.1895Z" fill="#805CFB"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M487.652 28L486.017 33.7297H497.393L499.014 28H487.652Z" fill="#805CFB"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M471.995 38.027L470.346 43.7568H494.549L496.17 38.027H471.995Z" fill="#805CFB"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M457.548 48.054L455.898 53.7838H491.705L493.326 48.054H457.548Z" fill="#805CFB"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M413.494 58.0811L411.844 63.8108H488.861L490.482 58.0811H413.494Z" fill="#805CFB"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M418.03 68.1081L416.38 73.8378H486.017L487.638 68.1081H418.03Z" fill="#805CFB"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M420.348 134L421.997 128.27H410.607L409 134H420.348Z" fill="#805CFB"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M436.019 123.973L437.668 118.243H413.451L411.844 123.973H436.019Z" fill="#805CFB"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M450.466 113.946L452.116 108.216H416.295L414.688 113.946H450.466Z" fill="#805CFB"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M494.521 103.919L496.17 98.1891H419.139L417.532 103.919H494.521Z" fill="#805CFB"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M489.97 93.8919L491.62 88.1622H421.983L420.376 93.8919H489.97Z" fill="#805CFB"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M483.372 83.8649L485.021 78.1351H419.679L418.073 83.8649H483.372Z" fill="#805CFB"/>
</svg>

@@ -0,0 +1,3 @@
<svg width="137" height="163" viewBox="0 0 137 163" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M104.652 29.325L103.017 35.0547H114.393L116.014 29.325H104.652ZM88.9956 39.352L87.346 45.0817H111.549L113.17 39.352H88.9956ZM72.8984 55.1088L74.5479 49.379H110.326L108.705 55.1088H72.8984ZM30.4937 59.4061L28.8442 65.1358H105.861L107.482 59.4061H30.4937ZM33.3802 75.1628L35.0298 69.4331H104.638L103.017 75.1628H33.3802ZM37.3478 135.325L38.9973 129.595H27.6069L26 135.325H37.3478ZM54.6682 119.568L53.0186 125.298H28.8442L30.4511 119.568H54.6682ZM67.4662 115.271L69.1157 109.541H33.2949L31.688 115.271H67.4662ZM113.17 99.5142L111.521 105.244H34.5322L36.1391 99.5142H113.17ZM106.97 95.2169L108.62 89.4871H38.9832L37.3763 95.2169H106.97ZM102.021 79.4601L100.372 85.1898H35.0724L36.6793 79.4601H102.021Z" fill="#805CFB"/>
</svg>

api/core/model_runtime/model_providers/upstage/_common.py (new file, 57 lines)

@@ -0,0 +1,57 @@
from collections.abc import Mapping

import openai
from httpx import Timeout

from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)


class _CommonUpstage:
    def _to_credential_kwargs(self, credentials: Mapping) -> dict:
        """
        Transform credentials to kwargs for model instance

        :param credentials:
        :return:
        """
        credentials_kwargs = {
            "api_key": credentials['upstage_api_key'],
            "base_url": "https://api.upstage.ai/v1/solar",
            "timeout": Timeout(315.0, read=300.0, write=20.0, connect=10.0),
            "max_retries": 1
        }

        return credentials_kwargs

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
            InvokeServerUnavailableError: [openai.InternalServerError],
            InvokeRateLimitError: [openai.RateLimitError],
            InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
            InvokeBadRequestError: [
                openai.BadRequestError,
                openai.NotFoundError,
                openai.UnprocessableEntityError,
                openai.APIError,
            ],
        }
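
Upstage exposes an OpenAI-compatible endpoint, so these kwargs are fed straight into the official OpenAI SDK. A minimal standalone sketch of the same wiring (illustrative only, not part of this commit; it assumes a valid key in the UPSTAGE_API_KEY environment variable):

import os

from httpx import Timeout
from openai import OpenAI

# Same kwargs that _CommonUpstage._to_credential_kwargs builds, filled in by hand.
client = OpenAI(
    api_key=os.environ["UPSTAGE_API_KEY"],  # assumption: key supplied via environment
    base_url="https://api.upstage.ai/v1/solar",
    timeout=Timeout(315.0, read=300.0, write=20.0, connect=10.0),
    max_retries=1,
)

# Mirrors the "ping" request used by validate_credentials further down.
completion = client.chat.completions.create(
    model="solar-1-mini-chat",
    messages=[{"role": "user", "content": "ping"}],
    temperature=0,
    max_tokens=10,
)
print(completion.choices[0].message.content)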

@@ -0,0 +1 @@
- solar-1-mini-chat

api/core/model_runtime/model_providers/upstage/llm/llm.py (new file, 575 lines)

@@ -0,0 +1,575 @@
import logging
from collections.abc import Generator
from typing import Optional, Union, cast

from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall
from tokenizers import Tokenizer

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.upstage._common import _CommonUpstage

logger = logging.getLogger(__name__)

UPSTAGE_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""


class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
    """
    Model class for Upstage large language model.
    """

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """

        return self._chat_generate(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user
        )

    def _code_block_mode_wrapper(self,
                                 model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]:
        """
        Code block mode wrapper for invoking large language model
        """
        if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
            stop = stop or []
            self._transform_chat_json_prompts(
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=tools,
                stop=stop,
                stream=stream,
                user=user,
                response_format=model_parameters['response_format']
            )
            model_parameters.pop('response_format')

        return self._invoke(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user
        )

    def _transform_chat_json_prompts(self, model: str, credentials: dict,
                                     prompt_messages: list[PromptMessage], model_parameters: dict,
                                     tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                                     stream: bool = True, user: str | None = None, response_format: str = 'JSON') -> None:
        """
        Transform json prompts
        """
        if stop is None:
            stop = []
        if "```\n" not in stop:
            stop.append("```\n")
        if "\n```" not in stop:
            stop.append("\n```")

        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
            prompt_messages[0] = SystemPromptMessage(
                content=UPSTAGE_BLOCK_MODE_PROMPT
                .replace("{{instructions}}", prompt_messages[0].content)
                .replace("{{block}}", response_format)
            )
            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n"))
        else:
            prompt_messages.insert(0, SystemPromptMessage(
                content=UPSTAGE_BLOCK_MODE_PROMPT
                .replace("{{instructions}}", f"Please output a valid {response_format} object.")
                .replace("{{block}}", response_format)
            ))
            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return:
        """
        return self._num_tokens_from_messages(model, prompt_messages, tools)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = OpenAI(**credentials_kwargs)

            client.chat.completions.create(
                messages=[{"role": "user", "content": "ping"}],
                model=model,
                temperature=0,
                max_tokens=10,
                stream=False
            )
        except Exception as e:
            raise CredentialsValidateFailedError(str(e))

    def _chat_generate(self, model: str, credentials: dict,
                       prompt_messages: list[PromptMessage], model_parameters: dict,
                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                       stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
        credentials_kwargs = self._to_credential_kwargs(credentials)
        client = OpenAI(**credentials_kwargs)

        extra_model_kwargs = {}

        if tools:
            extra_model_kwargs["functions"] = [{
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.parameters
            } for tool in tools]

        if stop:
            extra_model_kwargs["stop"] = stop

        if user:
            extra_model_kwargs["user"] = user

        # chat model
        response = client.chat.completions.create(
            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
            model=model,
            stream=stream,
            **model_parameters,
            **extra_model_kwargs,
        )

        if stream:
            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)

    def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
                                       prompt_messages: list[PromptMessage],
                                       tools: Optional[list[PromptMessageTool]] = None) -> LLMResult:
        """
        Handle llm chat response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: llm response
        """
        assistant_message = response.choices[0].message
        # assistant_message_tool_calls = assistant_message.tool_calls
        assistant_message_function_call = assistant_message.function_call

        # extract tool calls from response
        # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
        function_call = self._extract_response_function_call(assistant_message_function_call)
        tool_calls = [function_call] if function_call else []

        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=assistant_message.content,
            tool_calls=tool_calls
        )

        # calculate num tokens
        if response.usage:
            # transform usage
            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
        else:
            # calculate num tokens
            prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
            completion_tokens = self._num_tokens_from_messages(model, [assistant_prompt_message])

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        # transform response
        response = LLMResult(
            model=response.model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage,
            system_fingerprint=response.system_fingerprint,
        )

        return response

    def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk],
                                              prompt_messages: list[PromptMessage],
                                              tools: Optional[list[PromptMessageTool]] = None) -> Generator:
        """
        Handle llm chat stream response

        :param model: model name
        :param response: response
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: llm response chunk generator
        """
        full_assistant_content = ''
        delta_assistant_message_function_call_storage: Optional[ChoiceDeltaFunctionCall] = None
        prompt_tokens = 0
        completion_tokens = 0
        final_tool_calls = []
        final_chunk = LLMResultChunk(
            model=model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=0,
                message=AssistantPromptMessage(content=''),
            )
        )

        for chunk in response:
            if len(chunk.choices) == 0:
                if chunk.usage:
                    # calculate num tokens
                    prompt_tokens = chunk.usage.prompt_tokens
                    completion_tokens = chunk.usage.completion_tokens
                continue

            delta = chunk.choices[0]
            has_finish_reason = delta.finish_reason is not None

            if not has_finish_reason and (delta.delta.content is None or delta.delta.content == '') and \
                    delta.delta.function_call is None:
                continue

            # assistant_message_tool_calls = delta.delta.tool_calls
            assistant_message_function_call = delta.delta.function_call

            # extract tool calls from response
            if delta_assistant_message_function_call_storage is not None:
                # handle process of stream function call
                if assistant_message_function_call:
                    # message has not ended ever
                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
                    continue
                else:
                    # message has ended
                    assistant_message_function_call = delta_assistant_message_function_call_storage
                    delta_assistant_message_function_call_storage = None
            else:
                if assistant_message_function_call:
                    # start of stream function call
                    delta_assistant_message_function_call_storage = assistant_message_function_call
                    if delta_assistant_message_function_call_storage.arguments is None:
                        delta_assistant_message_function_call_storage.arguments = ''
                    if not has_finish_reason:
                        continue

            # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
            function_call = self._extract_response_function_call(assistant_message_function_call)
            tool_calls = [function_call] if function_call else []
            if tool_calls:
                final_tool_calls.extend(tool_calls)

            # transform assistant message to prompt message
            assistant_prompt_message = AssistantPromptMessage(
                content=delta.delta.content if delta.delta.content else '',
                tool_calls=tool_calls
            )

            full_assistant_content += delta.delta.content if delta.delta.content else ''

            if has_finish_reason:
                final_chunk = LLMResultChunk(
                    model=chunk.model,
                    prompt_messages=prompt_messages,
                    system_fingerprint=chunk.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=delta.index,
                        message=assistant_prompt_message,
                        finish_reason=delta.finish_reason,
                    )
                )
            else:
                yield LLMResultChunk(
                    model=chunk.model,
                    prompt_messages=prompt_messages,
                    system_fingerprint=chunk.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=delta.index,
                        message=assistant_prompt_message,
                    )
                )

        if not prompt_tokens:
            prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)

        if not completion_tokens:
            full_assistant_prompt_message = AssistantPromptMessage(
                content=full_assistant_content,
                tool_calls=final_tool_calls
            )
            completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message])

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
        final_chunk.delta.usage = usage

        yield final_chunk

    def _extract_response_tool_calls(self,
                                     response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
            -> list[AssistantPromptMessage.ToolCall]:
        """
        Extract tool calls from response

        :param response_tool_calls: response tool calls
        :return: list of tool calls
        """
        tool_calls = []
        if response_tool_calls:
            for response_tool_call in response_tool_calls:
                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                    name=response_tool_call.function.name,
                    arguments=response_tool_call.function.arguments
                )

                tool_call = AssistantPromptMessage.ToolCall(
                    id=response_tool_call.id,
                    type=response_tool_call.type,
                    function=function
                )
                tool_calls.append(tool_call)

        return tool_calls

    def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
            -> AssistantPromptMessage.ToolCall:
        """
        Extract function call from response

        :param response_function_call: response function call
        :return: tool call
        """
        tool_call = None
        if response_function_call:
            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                name=response_function_call.name,
                arguments=response_function_call.arguments
            )

            tool_call = AssistantPromptMessage.ToolCall(
                id=response_function_call.name,
                type="function",
                function=function
            )

        return tool_call

    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
        """
        Convert PromptMessage to dict for Upstage API
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": "user", "content": message.content}
            else:
                sub_messages = []
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(TextPromptMessageContent, message_content)
                        sub_message_dict = {
                            "type": "text",
                            "text": message_content.data
                        }
                        sub_messages.append(sub_message_dict)
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(ImagePromptMessageContent, message_content)
                        sub_message_dict = {
                            "type": "image_url",
                            "image_url": {
                                "url": message_content.data,
                                "detail": message_content.detail.value
                            }
                        }
                        sub_messages.append(sub_message_dict)

                message_dict = {"role": "user", "content": sub_messages}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
            if message.tool_calls:
                # message_dict["tool_calls"] = [tool_call.dict() for tool_call in
                #                               message.tool_calls]
                function_call = message.tool_calls[0]
                message_dict["function_call"] = {
                    "name": function_call.function.name,
                    "arguments": function_call.function.arguments,
                }
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ToolPromptMessage):
            message = cast(ToolPromptMessage, message)
            # message_dict = {
            #     "role": "tool",
            #     "content": message.content,
            #     "tool_call_id": message.tool_call_id
            # }
            message_dict = {
                "role": "function",
                "content": message.content,
                "name": message.tool_call_id
            }
        else:
            raise ValueError(f"Got unknown type {message}")

        if message.name:
            message_dict["name"] = message.name

        return message_dict

    def _get_tokenizer(self) -> Tokenizer:
        return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer")

    def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Calculate num tokens for solar with the Hugging Face Solar tokenizer.
        The Solar tokenizer is published on Hugging Face: https://huggingface.co/upstage/solar-1-mini-tokenizer
        """
        tokenizer = self._get_tokenizer()
        tokens_per_message = 5  # <|im_start|>{role}\n{message}<|im_end|>
        tokens_prefix = 1  # <|startoftext|>
        tokens_suffix = 3  # <|im_start|>assistant\n

        num_tokens = 0
        num_tokens += tokens_prefix

        messages_dict = [self._convert_prompt_message_to_dict(message) for message in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                if isinstance(value, list):
                    text = ''
                    for item in value:
                        if isinstance(item, dict) and item['type'] == 'text':
                            text += item['text']
                    value = text

                if key == "tool_calls":
                    for tool_call in value:
                        for t_key, t_value in tool_call.items():
                            num_tokens += len(tokenizer.encode(t_key, add_special_tokens=False))
                            if t_key == "function":
                                for f_key, f_value in t_value.items():
                                    num_tokens += len(tokenizer.encode(f_key, add_special_tokens=False))
                                    num_tokens += len(tokenizer.encode(f_value, add_special_tokens=False))
                            else:
                                num_tokens += len(tokenizer.encode(t_key, add_special_tokens=False))
                                num_tokens += len(tokenizer.encode(t_value, add_special_tokens=False))
                else:
                    num_tokens += len(tokenizer.encode(str(value), add_special_tokens=False))
        num_tokens += tokens_suffix

        if tools:
            num_tokens += self._num_tokens_for_tools(tokenizer, tools)

        return num_tokens

    def _num_tokens_for_tools(self, tokenizer: Tokenizer, tools: list[PromptMessageTool]) -> int:
        """
        Calculate num tokens for tool calling with upstage tokenizer.

        :param tokenizer: huggingface tokenizer
        :param tools: tools for tool calling
        :return: number of tokens
        """
        num_tokens = 0
        for tool in tools:
            num_tokens += len(tokenizer.encode('type'))
            num_tokens += len(tokenizer.encode('function'))

            # calculate num tokens for function object
            num_tokens += len(tokenizer.encode('name'))
            num_tokens += len(tokenizer.encode(tool.name))
            num_tokens += len(tokenizer.encode('description'))
            num_tokens += len(tokenizer.encode(tool.description))
            parameters = tool.parameters
            num_tokens += len(tokenizer.encode('parameters'))
            if 'title' in parameters:
                num_tokens += len(tokenizer.encode('title'))
                num_tokens += len(tokenizer.encode(parameters.get("title")))
            num_tokens += len(tokenizer.encode('type'))
            num_tokens += len(tokenizer.encode(parameters.get("type")))
            if 'properties' in parameters:
                num_tokens += len(tokenizer.encode('properties'))
                for key, value in parameters.get('properties').items():
                    num_tokens += len(tokenizer.encode(key))
                    for field_key, field_value in value.items():
                        num_tokens += len(tokenizer.encode(field_key))
                        if field_key == 'enum':
                            for enum_field in field_value:
                                num_tokens += 3
                                num_tokens += len(tokenizer.encode(enum_field))
                        else:
                            num_tokens += len(tokenizer.encode(field_key))
                            num_tokens += len(tokenizer.encode(str(field_value)))
            if 'required' in parameters:
                num_tokens += len(tokenizer.encode('required'))
                for required_field in parameters['required']:
                    num_tokens += 3
                    num_tokens += len(tokenizer.encode(required_field))

        return num_tokens
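
Token counting above never calls the API: it rebuilds the Solar chat template locally with the Hugging Face tokenizer, adding 1 token for <|startoftext|>, 5 per message for the <|im_start|>{role}\n{message}<|im_end|> wrapper, and 3 for the trailing <|im_start|>assistant\n. A stripped-down sketch of that arithmetic for plain text messages (illustrative only, not part of this commit; it assumes the upstage/solar-1-mini-tokenizer download is available):

from tokenizers import Tokenizer

# Illustrative sketch of _num_tokens_from_messages for plain text messages.
tokenizer = Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer")

def count_chat_tokens(messages: list[dict]) -> int:
    num_tokens = 1  # <|startoftext|>
    for message in messages:
        num_tokens += 5  # <|im_start|>{role}\n{message}<|im_end|>
        for value in message.values():
            num_tokens += len(tokenizer.encode(str(value), add_special_tokens=False))
    return num_tokens + 3  # <|im_start|>assistant\n

# The integration test further down expects 13 tokens for this single user message.
print(count_chat_tokens([{"role": "user", "content": "Hello World!"}]))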

@@ -0,0 +1,43 @@
model: solar-1-mini-chat
label:
  zh_Hans: solar-1-mini-chat
  en_US: solar-1-mini-chat
  ko_KR: solar-1-mini-chat
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 32768
  - name: seed
    label:
      zh_Hans: 种子
      en_US: Seed
    type: int
    help:
      zh_Hans:
        如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint
        响应参数来监视变化。
      en_US:
        If specified, model will make a best effort to sample deterministically,
        such that repeated requests with the same seed and parameters should return
        the same result. Determinism is not guaranteed, and you should refer to the
        system_fingerprint response parameter to monitor changes in the backend.
    required: false
pricing:
  input: "0.5"
  output: "0.5"
  unit: "0.000001"
  currency: USD

@@ -0,0 +1,9 @@
model: solar-embedding-1-large-passage
model_type: text-embedding
model_properties:
  context_size: 4000
  max_chunks: 32
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: 'USD'

@@ -0,0 +1,9 @@
model: solar-embedding-1-large-query
model_type: text-embedding
model_properties:
  context_size: 4000
  max_chunks: 32
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: 'USD'
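
Upstage ships the embedding model in two variants; as the names suggest (an inference here, not stated in the YAML), the passage model is meant for the documents you index and the query model for the search query. A hedged sketch of calling both through the same OpenAI-compatible endpoint (illustrative only, not part of this commit):

import os

from openai import OpenAI

# Assumption: same base_url and env-var key as in the earlier sketch.
client = OpenAI(api_key=os.environ["UPSTAGE_API_KEY"],
                base_url="https://api.upstage.ai/v1/solar")

# Index-side vectors with the passage model...
docs = client.embeddings.create(
    model="solar-embedding-1-large-passage",
    input=["Dify is an LLM application development platform."],
)
# ...and a query-side vector with the query model.
query = client.embeddings.create(
    model="solar-embedding-1-large-query",
    input="What is Dify?",
)
print(len(docs.data[0].embedding), len(query.data[0].embedding))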

@@ -0,0 +1,195 @@
import base64
import time
from collections.abc import Mapping
from typing import Union

import numpy as np
from openai import OpenAI
from tokenizers import Tokenizer

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.upstage._common import _CommonUpstage


class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
    """
    Model class for Upstage text embedding model.
    """
    def _get_tokenizer(self) -> Tokenizer:
        return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer")

    def _invoke(self, model: str, credentials: dict, texts: list[str], user: str | None = None) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """

        credentials_kwargs = self._to_credential_kwargs(credentials)
        client = OpenAI(**credentials_kwargs)

        extra_model_kwargs = {}
        if user:
            extra_model_kwargs["user"] = user
        extra_model_kwargs["encoding_format"] = "base64"

        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)

        embeddings: list[list[float]] = [[] for _ in range(len(texts))]
        tokens = []
        indices = []
        used_tokens = 0

        tokenizer = self._get_tokenizer()

        for i, text in enumerate(texts):
            token = tokenizer.encode(text, add_special_tokens=False).tokens
            for j in range(0, len(token), context_size):
                tokens += [token[j:j+context_size]]
                indices += [i]

        batched_embeddings = []
        _iter = range(0, len(tokens), max_chunks)

        for i in _iter:
            embeddings_batch, embedding_used_tokens = self._embedding_invoke(
                model=model,
                client=client,
                texts=tokens[i:i+max_chunks],
                extra_model_kwargs=extra_model_kwargs,
            )

            used_tokens += embedding_used_tokens
            batched_embeddings += embeddings_batch

        results: list[list[list[float]]] = [[] for _ in range(len(texts))]
        num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]

        for i in range(len(indices)):
            results[indices[i]].append(batched_embeddings[i])
            num_tokens_in_batch[indices[i]].append(len(tokens[i]))

        for i in range(len(texts)):
            _result = results[i]
            if len(_result) == 0:
                embeddings_batch, embedding_used_tokens = self._embedding_invoke(
                    model=model,
                    client=client,
                    texts=[texts[i]],
                    extra_model_kwargs=extra_model_kwargs,
                )
                used_tokens += embedding_used_tokens
                average = embeddings_batch[0]
            else:
                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
            embeddings[i] = (average / np.linalg.norm(average)).tolist()

        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=used_tokens
        )

        return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        if len(texts) == 0:
            return 0

        tokenizer = self._get_tokenizer()

        total_num_tokens = 0
        for text in texts:
            # calculate the number of tokens in the encoded text
            tokenized_text = tokenizer.encode(text)
            total_num_tokens += len(tokenized_text)

        return total_num_tokens

    def validate_credentials(self, model: str, credentials: Mapping) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            # transform credentials to kwargs for model instance
            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = OpenAI(**credentials_kwargs)

            # call embedding model
            self._embedding_invoke(
                model=model,
                client=client,
                texts=['ping'],
                extra_model_kwargs={}
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict) -> tuple[list[list[float]], int]:
        """
        Invoke embedding model
        :param model: model name
        :param client: model client
        :param texts: texts to embed
        :param extra_model_kwargs: extra model kwargs
        :return: embeddings and used tokens
        """
        response = client.embeddings.create(
            model=model,
            input=texts,
            **extra_model_kwargs
        )

        if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64':
            return ([list(np.frombuffer(base64.b64decode(embedding.embedding), dtype=np.float32)) for embedding in response.data], response.usage.total_tokens)

        return [data.embedding for data in response.data], response.usage.total_tokens

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            tokens=tokens,
            price_type=PriceType.INPUT
        )

        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage
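
Two details of _invoke are easy to miss: embeddings are requested with encoding_format="base64" and decoded via np.frombuffer, and any text longer than the 4000-token context is split into chunks whose embeddings are merged by a token-count-weighted average and then re-normalized. A small self-contained sketch of that merge step (illustrative only, not part of this commit; the vectors and counts are made up):

import numpy as np

# Two made-up chunk embeddings for one long input text.
chunk_embeddings = np.array([[0.1, 0.3, 0.4],
                             [0.2, 0.1, 0.0]])
chunk_token_counts = [4000, 250]  # tokens contributed by each chunk

# Weighted average followed by L2 normalization, as in UpstageTextEmbeddingModel._invoke.
average = np.average(chunk_embeddings, axis=0, weights=chunk_token_counts)
unit_vector = average / np.linalg.norm(average)
print(unit_vector.tolist())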

api/core/model_runtime/model_providers/upstage/upstage.py (new file, 32 lines)

@@ -0,0 +1,32 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class UpstageProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials
        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            model_instance.validate_credentials(
                model="solar-1-mini-chat",
                credentials=credentials
            )
        except CredentialsValidateFailedError as e:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise e
        except Exception as e:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise e

api/core/model_runtime/model_providers/upstage/upstage.yaml (new file, 49 lines)

@@ -0,0 +1,49 @@
provider: upstage
label:
  en_US: Upstage
description:
  en_US: Models provided by Upstage, such as Solar-1-mini-chat.
  zh_Hans: Upstage 提供的模型,例如 Solar-1-mini-chat.
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.svg
background: "#FFFFFF"
help:
  title:
    en_US: Get your API Key from Upstage
    zh_Hans: 从 Upstage 获取 API Key
  url:
    en_US: https://console.upstage.ai/api-keys
supported_model_types:
  - llm
  - text-embedding
configurate_methods:
  - predefined-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: upstage_api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
provider_credential_schema:
  credential_form_schemas:
    - variable: upstage_api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key

@@ -4,7 +4,7 @@ set -e

 if [[ "${MIGRATION_ENABLED}" == "true" ]]; then
   echo "Running migrations"
-  flask upgrade-db
+  flask db upgrade
 fi

 if [[ "${MODE}" == "worker" ]]; then

@@ -73,6 +73,7 @@ quote-style = "single"

 [tool.pytest_env]
 OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii"
+UPSTAGE_API_KEY = "up-aaaaaaaaaaaaaaaaaaaa"
 AZURE_OPENAI_API_BASE = "https://difyai-openai.openai.azure.com"
 AZURE_OPENAI_API_KEY = "xxxxb1707exxxxxxxxxxaaxxxxxf94"
 ANTHROPIC_API_KEY = "sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz"

api/tests/integration_tests/model_runtime/upstage/test_llm.py (new file, 245 lines)

@@ -0,0 +1,245 @@
import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.upstage.llm.llm import UpstageLargeLanguageModel

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


def test_predefined_models():
    model = UpstageLargeLanguageModel()
    model_schemas = model.predefined_models()

    assert len(model_schemas) >= 1
    assert isinstance(model_schemas[0], AIModelEntity)

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
    model = UpstageLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        # model name to gpt-3.5-turbo because of mocking
        model.validate_credentials(
            model='gpt-3.5-turbo',
            credentials={
                'upstage_api_key': 'invalid_key'
            }
        )

    model.validate_credentials(
        model='solar-1-mini-chat',
        credentials={
            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
        }
    )

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_invoke_chat_model(setup_openai_mock):
    model = UpstageLargeLanguageModel()

    result = model.invoke(
        model='solar-1-mini-chat',
        credentials={
            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.0,
            'top_p': 1.0,
            'presence_penalty': 0.0,
            'frequency_penalty': 0.0,
            'max_tokens': 10
        },
        stop=['How'],
        stream=False,
        user="abc-123"
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_invoke_chat_model_with_tools(setup_openai_mock):
    model = UpstageLargeLanguageModel()

    result = model.invoke(
        model='solar-1-mini-chat',
        credentials={
            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content="what's the weather today in London?",
            )
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens': 100
        },
        tools=[
            PromptMessageTool(
                name='get_weather',
                description='Determine weather in my location',
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state e.g. San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": [
                                "c",
                                "f"
                            ]
                        }
                    },
                    "required": [
                        "location"
                    ]
                }
            ),
            PromptMessageTool(
                name='get_stock_price',
                description='Get the current stock price',
                parameters={
                    "type": "object",
                    "properties": {
                        "symbol": {
                            "type": "string",
                            "description": "The stock symbol"
                        }
                    },
                    "required": [
                        "symbol"
                    ]
                }
            )
        ],
        stream=False,
        user="abc-123"
    )

    assert isinstance(result, LLMResult)
    assert isinstance(result.message, AssistantPromptMessage)
    assert len(result.message.tool_calls) > 0

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_invoke_stream_chat_model(setup_openai_mock):
    model = UpstageLargeLanguageModel()

    result = model.invoke(
        model='solar-1-mini-chat',
        credentials={
            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens': 100
        },
        stream=True,
        user="abc-123"
    )

    assert isinstance(result, Generator)

    for chunk in result:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
        if chunk.delta.finish_reason is not None:
            assert chunk.delta.usage is not None
            assert chunk.delta.usage.completion_tokens > 0


def test_get_num_tokens():
    model = UpstageLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='solar-1-mini-chat',
        credentials={
            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ]
    )

    assert num_tokens == 13

    num_tokens = model.get_num_tokens(
        model='solar-1-mini-chat',
        credentials={
            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        tools=[
            PromptMessageTool(
                name='get_weather',
                description='Determine weather in my location',
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state e.g. San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": [
                                "c",
                                "f"
                            ]
                        }
                    },
                    "required": [
                        "location"
                    ]
                }
            ),
        ]
    )

    assert num_tokens == 106

@@ -0,0 +1,23 @@
import os

import pytest

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.upstage.upstage import UpstageProvider
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_validate_provider_credentials(setup_openai_mock):
    provider = UpstageProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(
            credentials={}
        )

    provider.validate_provider_credentials(
        credentials={
            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
        }
    )

@@ -0,0 +1,67 @@
import os

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.upstage.text_embedding.text_embedding import UpstageTextEmbeddingModel
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
def test_validate_credentials(setup_openai_mock):
    model = UpstageTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='solar-embedding-1-large-passage',
            credentials={
                'upstage_api_key': 'invalid_key'
            }
        )

    model.validate_credentials(
        model='solar-embedding-1-large-passage',
        credentials={
            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
        }
    )

@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
def test_invoke_model(setup_openai_mock):
    model = UpstageTextEmbeddingModel()

    result = model.invoke(
        model='solar-embedding-1-large-passage',
        credentials={
            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'),
        },
        texts=[
            "hello",
            "world",
            " ".join(["long_text"] * 100),
            " ".join(["another_long_text"] * 100)
        ],
        user="abc-123"
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 4
    assert result.usage.total_tokens == 2


def test_get_num_tokens():
    model = UpstageTextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='solar-embedding-1-large-passage',
        credentials={
            'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'),
        },
        texts=[
            "hello",
            "world"
        ]
    )

    assert num_tokens == 5

@@ -5,4 +5,6 @@ pytest api/tests/integration_tests/model_runtime/anthropic \
   api/tests/integration_tests/model_runtime/azure_openai \
   api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm \
   api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference \
-  api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py
+  api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py \
+  api/tests/integration_tests/model_runtime/upstage