Mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 11:42:29 +08:00)
Add Fireworks AI as new model provider (#8428)
This commit is contained in: parent c8b9bdebfe, commit 0665268578
@@ -37,3 +37,4 @@
 - siliconflow
 - perfxcloud
 - zhinao
+- fireworks
@@ -0,0 +1,3 @@
<svg width="130" role="graphics-symbol" aria-label="Fireworks AI Home" viewBox="0 0 835 130" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M112.65 0L91.33 51.09L69.99 0H56.3L79.69 55.85C81.63 60.51 86.18 63.52 91.25 63.52C96.32 63.52 100.86 60.51 102.81 55.87L126.34 0H112.65ZM121.76 77.84L160.76 38.41L155.44 25.86L112.84 69.01C109.28 72.62 108.26 77.94 110.23 82.6C112.19 87.22 116.72 90.21 121.77 90.21L121.79 90.23L182.68 90.08L177.36 77.53L121.77 77.84H121.76ZM21.92 38.38L27.24 25.83L69.84 68.98C73.4 72.58 74.43 77.92 72.45 82.57C70.49 87.2 65.94 90.18 60.91 90.18L0.02 90.04L0 90.06L5.32 77.51L60.91 77.82L21.92 38.38Z" fill="#6720FF"></path>
<path d="M231.32 85.2198L231.33 85.2298H241.8V49.1698H275.62V39.8198H241.8V16.3598H279V7.00977H231.32V85.2198Z" class="fill-black dark:fill-white"></path><path d="M299.68 28.73H289.86V85.22H299.68V28.73Z" class="fill-black dark:fill-white"></path><path d="M324.58 36.2198H324.59C324.37 36.7598 324.16 37.0898 323.5 37.0898C322.95 37.0898 322.74 36.8798 322.74 36.3398V28.7298H312.92V85.2198H322.72V53.1598C322.72 42.3098 327.75 38.0698 337.24 38.0698H345.1V28.5098H338.77C331.03 28.5098 327.1 30.7898 324.58 36.2198Z" class="fill-black dark:fill-white"></path><path d="M377.76 78.3996C367.23 78.3996 359.37 72.4196 358.71 59.7196H404.6V54.2796C404.6 38.5296 395 27.1196 377.53 27.1196C360.06 27.1196 348.93 38.5296 348.93 56.9896C348.93 75.4496 359.73 86.8596 377.74 86.8596C395.75 86.8596 403.15 75.8996 404.81 67.3196H394.57C392.98 73.7396 388.29 78.3996 377.76 78.3996ZM377.53 35.5696C387.91 35.5696 394.33 41.1196 394.78 51.5496H358.98C360.61 40.8896 368.14 35.5696 377.53 35.5696Z" class="fill-black dark:fill-white"></path><path d="M474.29 74.68C474.05 75.66 473.75 75.99 472.97 75.99C472.19 75.99 471.86 75.66 471.65 74.68L460.73 28.73H443.81L432.89 74.68C432.65 75.66 432.35 75.99 431.57 75.99C430.79 75.99 430.46 75.66 430.25 74.68L419.33 28.73H409.73V30.91H409.79L423.11 85.22H439.97L451.22 37.85C451.43 37.08 451.64 36.87 452.3 36.87C452.84 36.87 453.17 37.1 453.38 37.85L464.63 85.22H481.49L494.81 30.91V28.73H485.21L474.29 74.68Z" class="fill-black dark:fill-white"></path><path d="M529.05 27.1099C512.56 27.1099 499.47 37.4199 499.47 56.9799C499.47 76.5399 512.55 86.8499 529.05 86.8499C545.55 86.8499 558.64 76.5399 558.64 56.9799C558.64 37.4199 545.54 27.1099 529.05 27.1099ZM529.07 78.1599C517.61 78.1599 509.42 70.5699 509.42 56.9799C509.42 43.3899 517.61 35.7999 529.07 35.7999C540.53 35.7999 548.72 43.4099 548.72 56.9799C548.72 70.5499 540.53 78.1599 529.07 78.1599Z" class="fill-black dark:fill-white"></path><path d="M580.68 36.2198C580.47 36.7598 580.26 37.0898 579.6 37.0898C579.05 37.0898 578.841 36.8798 578.841 36.3398V28.7298H569.021V85.2098H578.82V53.1598C578.82 42.3098 583.851 38.0698 593.341 38.0698H601.201V28.5098H594.87C587.13 28.5098 583.2 30.7898 580.68 36.2198Z" class="fill-black dark:fill-white"></path><path d="M618.591 55.0198V7.00977H608.771V85.2698H618.591V67.2298L629.24 58.1498L650.42 85.2498H661.16V83.0698L636.49 51.9398L661.16 30.9098V28.7298H648.54L618.591 55.0198Z" class="fill-black dark:fill-white"></path><path d="M695.19 52.8899L687.12 51.3699C679.38 49.8999 675.99 48.2799 675.99 43.5999C675.99 38.9199 679.82 35.4499 688.98 35.4499C698.14 35.4499 703.38 38.9399 704.14 46.6499H714.14C713.03 32.8799 702.34 27.1299 688.94 27.1299C675.54 27.1299 666.13 32.8899 666.13 43.7399C666.13 54.5899 673.83 58.3499 684.91 60.4099L692.98 61.9299C700.84 63.3999 704.77 65.0899 704.77 69.9699C704.77 74.8499 700.83 78.4899 691.35 78.4899C681.87 78.4899 675.58 74.5799 674.82 67.0799H664.83C665.76 80.5499 676.73 86.8499 691.36 86.8499C705.99 86.8499 714.61 80.6099 714.61 69.4099C714.61 58.2099 705.55 54.8399 695.19 52.8899Z" class="fill-black dark:fill-white"></path><path d="M834.64 7.00977H823.63V85.2698H834.64V7.00977Z" class="fill-black dark:fill-white"></path><path d="M770.23 7.77L739.71 83.8398V85.2698H750.61L758.34 64.8398H795.08L802.81 85.2698H814.04V83.8598L783.3 7.00977H770.23ZM761.97 55.3798L775.09 21.0098H775.08C775.3 20.4198 775.87 20.0298 776.5 20.0298H777.04C777.67 20.0298 778.24 20.4198 778.46 21.0098L791.48 55.3798H761.97Z" class="fill-black dark:fill-white"></path><path d="M299.68 
7.00977H289.86V18.5298H299.68V7.00977Z" class="fill-black dark:fill-white"></path></svg>
After: Size 4.2 KiB
@@ -0,0 +1,5 @@
<svg width="638" height="315" viewBox="0 0 638 315" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M318.563 221.755C300.863 221.755 284.979 211.247 278.206 194.978L196.549 0H244.342L318.842 178.361L393.273 0H441.066L358.92 195.048C352.112 211.247 336.263 221.755 318.563 221.755Z" fill="#6720FF"/>
<path d="M425.111 314.933C407.481 314.933 391.667 304.494 384.824 288.366C377.947 272.097 381.507 253.524 393.936 240.921L542.657 90.2803L561.229 134.094L425.076 271.748L619.147 270.666L637.72 314.479L425.146 315.003L425.076 314.933H425.111Z" fill="#6720FF"/>
<path d="M0 314.408L18.5727 270.595L212.643 271.677L76.525 133.988L95.0977 90.1748L243.819 240.816C256.247 253.384 259.843 272.026 252.93 288.26C246.088 304.424 230.203 314.827 212.643 314.827L0.0698221 314.339L0 314.408Z" fill="#6720FF"/>
</svg>
After: Size 815 B
api/core/model_runtime/model_providers/fireworks/_common.py (new file, 52 lines)
@@ -0,0 +1,52 @@
from collections.abc import Mapping

import openai

from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)


class _CommonFireworks:
    def _to_credential_kwargs(self, credentials: Mapping) -> dict:
        """
        Transform credentials to kwargs for model instance

        :param credentials:
        :return:
        """
        credentials_kwargs = {
            "api_key": credentials["fireworks_api_key"],
            "base_url": "https://api.fireworks.ai/inference/v1",
            "max_retries": 1,
        }

        return credentials_kwargs

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke errors to unified error types.
        The key is the error type thrown to the caller;
        the value is the list of error types thrown by the model,
        which need to be converted into the unified error type for the caller.

        :return: invoke error mapping
        """
        return {
            InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
            InvokeServerUnavailableError: [openai.InternalServerError],
            InvokeRateLimitError: [openai.RateLimitError],
            InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
            InvokeBadRequestError: [
                openai.BadRequestError,
                openai.NotFoundError,
                openai.UnprocessableEntityError,
                openai.APIError,
            ],
        }
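Because Fireworks exposes an OpenAI-compatible endpoint, the kwargs built above can be passed straight to the official openai client. A minimal sketch, assuming a placeholder key value ("fw_xxx" is not a real credential):

from openai import OpenAI

from core.model_runtime.model_providers.fireworks._common import _CommonFireworks

# Placeholder credentials, shaped like the provider credential form.
credentials = {"fireworks_api_key": "fw_xxx"}

# _to_credential_kwargs yields api_key, base_url and max_retries for the client.
client = OpenAI(**_CommonFireworks()._to_credential_kwargs(credentials))
# Requests now target https://api.fireworks.ai/inference/v1 with the given key.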
@@ -0,0 +1,27 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class FireworksProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials.
        If validation fails, raise an exception.

        :param credentials: provider credentials, credential form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)
            model_instance.validate_credentials(
                model="accounts/fireworks/models/llama-v3p1-8b-instruct", credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
            raise ex
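Provider-level validation boils down to one cheap chat call against a small Llama 3.1 model with the supplied key. A sketch of how the check might be driven directly, assuming the provider class can be instantiated standalone and using a placeholder key:

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.fireworks.fireworks import FireworksProvider

provider = FireworksProvider()
try:
    provider.validate_provider_credentials(credentials={"fireworks_api_key": "fw_xxx"})
except CredentialsValidateFailedError as e:
    # Raised when the ping request against llama-v3p1-8b-instruct fails.
    print(f"Invalid Fireworks credentials: {e}")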
@@ -0,0 +1,29 @@
provider: fireworks
label:
  zh_Hans: Fireworks AI
  en_US: Fireworks AI
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.svg
background: "#FCFDFF"
help:
  title:
    en_US: Get your API Key from Fireworks AI
    zh_Hans: 从 Fireworks AI 获取 API Key
  url:
    en_US: https://fireworks.ai/account/api-keys
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: fireworks_api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
@@ -0,0 +1,16 @@
- llama-v3p1-405b-instruct
- llama-v3p1-70b-instruct
- llama-v3p1-8b-instruct
- llama-v3-70b-instruct
- mixtral-8x22b-instruct
- mixtral-8x7b-instruct
- firefunction-v2
- firefunction-v1
- gemma2-9b-it
- llama-v3-70b-instruct-hf
- llama-v3-8b-instruct
- llama-v3-8b-instruct-hf
- mixtral-8x7b-instruct-hf
- mythomax-l2-13b
- phi-3-vision-128k-instruct
- yi-large
@@ -0,0 +1,46 @@
model: accounts/fireworks/models/firefunction-v1
label:
  zh_Hans: Firefunction V1
  en_US: Firefunction V1
model_type: llm
features:
  - agent-thought
  - tool-call
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '0.5'
  output: '0.5'
  unit: '0.000001'
  currency: USD
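Each model YAML above maps directly onto the runtime's invocation arguments: names under parameter_rules become keys of model_parameters. A hedged sketch of a blocking call through the model class added later in this commit (prompt, key, and parameter values are illustrative):

from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.model_providers.fireworks.llm.llm import FireworksLargeLanguageModel

model = FireworksLargeLanguageModel()
result = model.invoke(
    model="accounts/fireworks/models/firefunction-v1",
    credentials={"fireworks_api_key": "fw_xxx"},  # placeholder key
    prompt_messages=[UserPromptMessage(content="Say hello.")],
    # Keys mirror parameter_rules above: temperature, top_p, top_k, max_tokens, ...
    model_parameters={"temperature": 0.2, "max_tokens": 64},
    stream=False,
)
print(result.message.content)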
@@ -0,0 +1,46 @@
model: accounts/fireworks/models/firefunction-v2
label:
  zh_Hans: Firefunction V2
  en_US: Firefunction V2
model_type: llm
features:
  - agent-thought
  - tool-call
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '0.9'
  output: '0.9'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,45 @@
model: accounts/fireworks/models/gemma2-9b-it
label:
  zh_Hans: Gemma2 9B Instruct
  en_US: Gemma2 9B Instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '0.2'
  output: '0.2'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,46 @@
model: accounts/fireworks/models/llama-v3-70b-instruct-hf
label:
  zh_Hans: Llama3 70B Instruct(HF version)
  en_US: Llama3 70B Instruct(HF version)
model_type: llm
features:
  - agent-thought
  - tool-call
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '0.9'
  output: '0.9'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,46 @@
model: accounts/fireworks/models/llama-v3-70b-instruct
label:
  zh_Hans: Llama3 70B Instruct
  en_US: Llama3 70B Instruct
model_type: llm
features:
  - agent-thought
  - tool-call
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '0.9'
  output: '0.9'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,46 @@
model: accounts/fireworks/models/llama-v3-8b-instruct-hf
label:
  zh_Hans: Llama3 8B Instruct(HF version)
  en_US: Llama3 8B Instruct(HF version)
model_type: llm
features:
  - agent-thought
  - tool-call
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '0.2'
  output: '0.2'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,46 @@
model: accounts/fireworks/models/llama-v3-8b-instruct
label:
  zh_Hans: Llama3 8B Instruct
  en_US: Llama3 8B Instruct
model_type: llm
features:
  - agent-thought
  - tool-call
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '0.2'
  output: '0.2'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,46 @@
model: accounts/fireworks/models/llama-v3p1-405b-instruct
label:
  zh_Hans: Llama3.1 405B Instruct
  en_US: Llama3.1 405B Instruct
model_type: llm
features:
  - agent-thought
  - tool-call
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '3'
  output: '3'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,46 @@
model: accounts/fireworks/models/llama-v3p1-70b-instruct
label:
  zh_Hans: Llama3.1 70B Instruct
  en_US: Llama3.1 70B Instruct
model_type: llm
features:
  - agent-thought
  - tool-call
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '0.2'
  output: '0.2'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,46 @@
model: accounts/fireworks/models/llama-v3p1-8b-instruct
label:
  zh_Hans: Llama3.1 8B Instruct
  en_US: Llama3.1 8B Instruct
model_type: llm
features:
  - agent-thought
  - tool-call
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '0.2'
  output: '0.2'
  unit: '0.000001'
  currency: USD
api/core/model_runtime/model_providers/fireworks/llm/llm.py (new file, 610 lines)
@@ -0,0 +1,610 @@
import logging
from collections.abc import Generator
from typing import Optional, Union, cast

from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.fireworks._common import _CommonFireworks

logger = logging.getLogger(__name__)

FIREWORKS_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""  # noqa: E501


class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel):
    """
    Model class for Fireworks large language model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """

        return self._chat_generate(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
        )

    def _code_block_mode_wrapper(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Code block mode wrapper for invoking large language model
        """
        if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
            stop = stop or []
            self._transform_chat_json_prompts(
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=tools,
                stop=stop,
                stream=stream,
                user=user,
                response_format=model_parameters["response_format"],
            )
            model_parameters.pop("response_format")

        return self._invoke(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
        )

    def _transform_chat_json_prompts(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: bool = True,
        user: str | None = None,
        response_format: str = "JSON",
    ) -> None:
        """
        Transform json prompts
        """
        if stop is None:
            stop = []
        if "```\n" not in stop:
            stop.append("```\n")
        if "\n```" not in stop:
            stop.append("\n```")

        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
            prompt_messages[0] = SystemPromptMessage(
                content=FIREWORKS_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace(
                    "{{block}}", response_format
                )
            )
            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n"))
        else:
            prompt_messages.insert(
                0,
                SystemPromptMessage(
                    content=FIREWORKS_BLOCK_MODE_PROMPT.replace(
                        "{{instructions}}", f"Please output a valid {response_format} object."
                    ).replace("{{block}}", response_format)
                ),
            )
            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))

    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return:
        """
        return self._num_tokens_from_messages(model, prompt_messages, tools)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = OpenAI(**credentials_kwargs)

            client.chat.completions.create(
                messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=10, stream=False
            )
        except Exception as e:
            raise CredentialsValidateFailedError(str(e))

    def _chat_generate(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        credentials_kwargs = self._to_credential_kwargs(credentials)
        client = OpenAI(**credentials_kwargs)

        extra_model_kwargs = {}

        if tools:
            extra_model_kwargs["functions"] = [
                {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools
            ]

        if stop:
            extra_model_kwargs["stop"] = stop

        if user:
            extra_model_kwargs["user"] = user

        # chat model
        response = client.chat.completions.create(
            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
            model=model,
            stream=stream,
            **model_parameters,
            **extra_model_kwargs,
        )

        if stream:
            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)

    def _handle_chat_generate_response(
        self,
        model: str,
        credentials: dict,
        response: ChatCompletion,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> LLMResult:
        """
        Handle llm chat response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: llm response
        """
        assistant_message = response.choices[0].message
        # assistant_message_tool_calls = assistant_message.tool_calls
        assistant_message_function_call = assistant_message.function_call

        # extract tool calls from response
        # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
        function_call = self._extract_response_function_call(assistant_message_function_call)
        tool_calls = [function_call] if function_call else []

        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)

        # calculate num tokens
        if response.usage:
            # transform usage
            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
        else:
            # calculate num tokens
            prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
            completion_tokens = self._num_tokens_from_messages(model, [assistant_prompt_message])

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        # transform response
        response = LLMResult(
            model=response.model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage,
            system_fingerprint=response.system_fingerprint,
        )

        return response

    def _handle_chat_generate_stream_response(
        self,
        model: str,
        credentials: dict,
        response: Stream[ChatCompletionChunk],
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> Generator:
        """
        Handle llm chat stream response

        :param model: model name
        :param response: response
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: llm response chunk generator
        """
        full_assistant_content = ""
        delta_assistant_message_function_call_storage: Optional[ChoiceDeltaFunctionCall] = None
        prompt_tokens = 0
        completion_tokens = 0
        final_tool_calls = []
        final_chunk = LLMResultChunk(
            model=model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=0,
                message=AssistantPromptMessage(content=""),
            ),
        )

        for chunk in response:
            if len(chunk.choices) == 0:
                if chunk.usage:
                    # calculate num tokens
                    prompt_tokens = chunk.usage.prompt_tokens
                    completion_tokens = chunk.usage.completion_tokens
                continue

            delta = chunk.choices[0]
            has_finish_reason = delta.finish_reason is not None

            if (
                not has_finish_reason
                and (delta.delta.content is None or delta.delta.content == "")
                and delta.delta.function_call is None
            ):
                continue

            # assistant_message_tool_calls = delta.delta.tool_calls
            assistant_message_function_call = delta.delta.function_call

            # extract tool calls from response
            if delta_assistant_message_function_call_storage is not None:
                # handle process of stream function call
                if assistant_message_function_call:
                    # message has not ended ever
                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
                    continue
                else:
                    # message has ended
                    assistant_message_function_call = delta_assistant_message_function_call_storage
                    delta_assistant_message_function_call_storage = None
            else:
                if assistant_message_function_call:
                    # start of stream function call
                    delta_assistant_message_function_call_storage = assistant_message_function_call
                    if delta_assistant_message_function_call_storage.arguments is None:
                        delta_assistant_message_function_call_storage.arguments = ""
                    if not has_finish_reason:
                        continue

            # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
            function_call = self._extract_response_function_call(assistant_message_function_call)
            tool_calls = [function_call] if function_call else []
            if tool_calls:
                final_tool_calls.extend(tool_calls)

            # transform assistant message to prompt message
            assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)

            full_assistant_content += delta.delta.content or ""

            if has_finish_reason:
                final_chunk = LLMResultChunk(
                    model=chunk.model,
                    prompt_messages=prompt_messages,
                    system_fingerprint=chunk.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=delta.index,
                        message=assistant_prompt_message,
                        finish_reason=delta.finish_reason,
                    ),
                )
            else:
                yield LLMResultChunk(
                    model=chunk.model,
                    prompt_messages=prompt_messages,
                    system_fingerprint=chunk.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=delta.index,
                        message=assistant_prompt_message,
                    ),
                )

        if not prompt_tokens:
            prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)

        if not completion_tokens:
            full_assistant_prompt_message = AssistantPromptMessage(
                content=full_assistant_content, tool_calls=final_tool_calls
            )
            completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message])

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
        final_chunk.delta.usage = usage

        yield final_chunk

    def _extract_response_tool_calls(
        self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]
    ) -> list[AssistantPromptMessage.ToolCall]:
        """
        Extract tool calls from response

        :param response_tool_calls: response tool calls
        :return: list of tool calls
        """
        tool_calls = []
        if response_tool_calls:
            for response_tool_call in response_tool_calls:
                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                    name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
                )

                tool_call = AssistantPromptMessage.ToolCall(
                    id=response_tool_call.id, type=response_tool_call.type, function=function
                )
                tool_calls.append(tool_call)

        return tool_calls

    def _extract_response_function_call(
        self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall
    ) -> AssistantPromptMessage.ToolCall:
        """
        Extract function call from response

        :param response_function_call: response function call
        :return: tool call
        """
        tool_call = None
        if response_function_call:
            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                name=response_function_call.name, arguments=response_function_call.arguments
            )

            tool_call = AssistantPromptMessage.ToolCall(
                id=response_function_call.name, type="function", function=function
            )

        return tool_call

    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
        """
        Convert PromptMessage to dict for Fireworks API
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": "user", "content": message.content}
            else:
                sub_messages = []
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(TextPromptMessageContent, message_content)
                        sub_message_dict = {"type": "text", "text": message_content.data}
                        sub_messages.append(sub_message_dict)
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(ImagePromptMessageContent, message_content)
                        sub_message_dict = {
                            "type": "image_url",
                            "image_url": {"url": message_content.data, "detail": message_content.detail.value},
                        }
                        sub_messages.append(sub_message_dict)

                message_dict = {"role": "user", "content": sub_messages}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
            if message.tool_calls:
                # message_dict["tool_calls"] = [tool_call.dict() for tool_call in
                # message.tool_calls]
                function_call = message.tool_calls[0]
                message_dict["function_call"] = {
                    "name": function_call.function.name,
                    "arguments": function_call.function.arguments,
                }
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ToolPromptMessage):
            message = cast(ToolPromptMessage, message)
            # message_dict = {
            #     "role": "tool",
            #     "content": message.content,
            #     "tool_call_id": message.tool_call_id
            # }
            message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id}
        else:
            raise ValueError(f"Got unknown type {message}")

        if message.name:
            message_dict["name"] = message.name

        return message_dict

    def _num_tokens_from_messages(
        self,
        model: str,
        messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
        credentials: dict = None,
    ) -> int:
        """
        Approximate num tokens with GPT2 tokenizer.
        """

        tokens_per_message = 3
        tokens_per_name = 1

        num_tokens = 0
        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Cast str(value) in case the message value is not a string
                # This occurs with function messages
                # TODO: The current token calculation method for the image type is not implemented,
                #  which need to download the image and then get the resolution for calculation,
                #  and will increase the request delay
                if isinstance(value, list):
                    text = ""
                    for item in value:
                        if isinstance(item, dict) and item["type"] == "text":
                            text += item["text"]

                    value = text

                if key == "tool_calls":
                    for tool_call in value:
                        for t_key, t_value in tool_call.items():
                            num_tokens += self._get_num_tokens_by_gpt2(t_key)
                            if t_key == "function":
                                for f_key, f_value in t_value.items():
                                    num_tokens += self._get_num_tokens_by_gpt2(f_key)
                                    num_tokens += self._get_num_tokens_by_gpt2(f_value)
                            else:
                                num_tokens += self._get_num_tokens_by_gpt2(t_key)
                                num_tokens += self._get_num_tokens_by_gpt2(t_value)
                else:
                    num_tokens += self._get_num_tokens_by_gpt2(str(value))

                if key == "name":
                    num_tokens += tokens_per_name

        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if tools:
            num_tokens += self._num_tokens_for_tools(tools)

        return num_tokens

    def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
        """
        Calculate num tokens for tool calling with tiktoken package.

        :param tools: tools for tool calling
        :return: number of tokens
        """
        num_tokens = 0
        for tool in tools:
            num_tokens += self._get_num_tokens_by_gpt2("type")
            num_tokens += self._get_num_tokens_by_gpt2("function")
            num_tokens += self._get_num_tokens_by_gpt2("function")

            # calculate num tokens for function object
            num_tokens += self._get_num_tokens_by_gpt2("name")
            num_tokens += self._get_num_tokens_by_gpt2(tool.name)
            num_tokens += self._get_num_tokens_by_gpt2("description")
            num_tokens += self._get_num_tokens_by_gpt2(tool.description)
            parameters = tool.parameters
            num_tokens += self._get_num_tokens_by_gpt2("parameters")
            if "title" in parameters:
                num_tokens += self._get_num_tokens_by_gpt2("title")
                num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title"))
            num_tokens += self._get_num_tokens_by_gpt2("type")
            num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type"))
            if "properties" in parameters:
                num_tokens += self._get_num_tokens_by_gpt2("properties")
                for key, value in parameters.get("properties").items():
                    num_tokens += self._get_num_tokens_by_gpt2(key)
                    for field_key, field_value in value.items():
                        num_tokens += self._get_num_tokens_by_gpt2(field_key)
                        if field_key == "enum":
                            for enum_field in field_value:
                                num_tokens += 3
                                num_tokens += self._get_num_tokens_by_gpt2(enum_field)
                        else:
                            num_tokens += self._get_num_tokens_by_gpt2(field_key)
                            num_tokens += self._get_num_tokens_by_gpt2(str(field_value))
            if "required" in parameters:
                num_tokens += self._get_num_tokens_by_gpt2("required")
                for required_field in parameters["required"]:
                    num_tokens += 3
                    num_tokens += self._get_num_tokens_by_gpt2(required_field)

        return num_tokens
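The _code_block_mode_wrapper path above emulates structured output: when response_format is JSON or XML, the system prompt is rewritten with FIREWORKS_BLOCK_MODE_PROMPT, an assistant message priming the code fence is appended, and fence-closing sequences are registered as stop words. A sketch of that transformation on a toy message list (purely illustrative; it exercises a private helper):

from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.model_providers.fireworks.llm.llm import FireworksLargeLanguageModel

llm = FireworksLargeLanguageModel()
messages = [SystemPromptMessage(content="Return the user's name."), UserPromptMessage(content="I am Ada.")]
stop: list[str] = []
llm._transform_chat_json_prompts(
    model="accounts/fireworks/models/llama-v3p1-8b-instruct",
    credentials={},  # unused by the transformation itself
    prompt_messages=messages,
    model_parameters={},
    stop=stop,
    response_format="JSON",
)
print(stop)                                 # ['```\n', '\n```'], the fence-closing stop words
print(messages[0].content.splitlines()[0])  # system prompt now embeds the JSON instructions
print(repr(messages[-1].content))           # '\n```JSON\n' primes the model's answer block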
@@ -0,0 +1,46 @@
model: accounts/fireworks/models/mixtral-8x22b-instruct
label:
  zh_Hans: Mixtral MoE 8x22B Instruct
  en_US: Mixtral MoE 8x22B Instruct
model_type: llm
features:
  - agent-thought
  - tool-call
model_properties:
  mode: chat
  context_size: 65536
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '1.2'
  output: '1.2'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,46 @@
model: accounts/fireworks/models/mixtral-8x7b-instruct-hf
label:
  zh_Hans: Mixtral MoE 8x7B Instruct(HF version)
  en_US: Mixtral MoE 8x7B Instruct(HF version)
model_type: llm
features:
  - agent-thought
  - tool-call
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '0.5'
  output: '0.5'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,46 @@
model: accounts/fireworks/models/mixtral-8x7b-instruct
label:
  zh_Hans: Mixtral MoE 8x7B Instruct
  en_US: Mixtral MoE 8x7B Instruct
model_type: llm
features:
  - agent-thought
  - tool-call
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '0.5'
  output: '0.5'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,46 @@
model: accounts/fireworks/models/mythomax-l2-13b
label:
  zh_Hans: MythoMax L2 13b
  en_US: MythoMax L2 13b
model_type: llm
features:
  - agent-thought
  - tool-call
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '0.2'
  output: '0.2'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,46 @@
model: accounts/fireworks/models/phi-3-vision-128k-instruct
label:
  zh_Hans: Phi3.5 Vision Instruct
  en_US: Phi3.5 Vision Instruct
model_type: llm
features:
  - agent-thought
  - vision
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '0.2'
  output: '0.2'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,45 @@
model: accounts/yi-01-ai/models/yi-large
label:
  zh_Hans: Yi-Large
  en_US: Yi-Large
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
  - name: max_tokens
    use_template: max_tokens
  - name: context_length_exceeded_behavior
    default: None
    label:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    help:
      zh_Hans: 上下文长度超出行为
      en_US: Context Length Exceeded Behavior
    type: string
    options:
      - None
      - truncate
      - error
  - name: response_format
    use_template: response_format
pricing:
  input: '3'
  output: '3'
  unit: '0.000001'
  currency: USD
@ -100,6 +100,7 @@ exclude = [
 [tool.pytest_env]
 OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii"
 UPSTAGE_API_KEY = "up-aaaaaaaaaaaaaaaaaaaa"
+FIREWORKS_API_KEY = "fw_aaaaaaaaaaaaaaaaaaaa"
 AZURE_OPENAI_API_BASE = "https://difyai-openai.openai.azure.com"
 AZURE_OPENAI_API_KEY = "xxxxb1707exxxxxxxxxxaaxxxxxf94"
 ANTHROPIC_API_KEY = "sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz"
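These pytest-env entries are placeholder credentials injected into the test process's environment for the duration of the run; the Fireworks tests in the next file read the key back via os.environ. A minimal sketch of that pattern (the helper name is illustrative, not from this commit):

# Illustrative sketch, not code from this PR: how tests consume the pytest-env value.
import os

def fireworks_test_credentials() -> dict[str, str | None]:
    # FIREWORKS_API_KEY is set to the "fw_..." placeholder by [tool.pytest_env]
    return {"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}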
186  api/tests/integration_tests/model_runtime/fireworks/test_llm.py  Normal file
@ -0,0 +1,186 @@
import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.fireworks.llm.llm import FireworksLargeLanguageModel

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


def test_predefined_models():
    model = FireworksLargeLanguageModel()
    model_schemas = model.predefined_models()

    assert len(model_schemas) >= 1
    assert isinstance(model_schemas[0], AIModelEntity)


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
    model = FireworksLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        # the model name is set to gpt-3.5-turbo because of the mock
        model.validate_credentials(model="gpt-3.5-turbo", credentials={"fireworks_api_key": "invalid_key"})

    model.validate_credentials(
        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
    )


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model(setup_openai_mock):
    model = FireworksLargeLanguageModel()

    result = model.invoke(
        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={
            "temperature": 0.0,
            "top_p": 1.0,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
            "max_tokens": 10,
        },
        stop=["How"],
        stream=False,
        user="foo",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model_with_tools(setup_openai_mock):
    model = FireworksLargeLanguageModel()

    result = model.invoke(
        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content="what's the weather today in London?",
            ),
        ],
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        tools=[
            PromptMessageTool(
                name="get_weather",
                description="Determine weather in my location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            ),
            PromptMessageTool(
                name="get_stock_price",
                description="Get the current stock price",
                parameters={
                    "type": "object",
                    "properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
                    "required": ["symbol"],
                },
            ),
        ],
        stream=False,
        user="foo",
    )

    assert isinstance(result, LLMResult)
    assert isinstance(result.message, AssistantPromptMessage)
    assert len(result.message.tool_calls) > 0


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_chat_model(setup_openai_mock):
    model = FireworksLargeLanguageModel()

    result = model.invoke(
        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=True,
        user="foo",
    )

    assert isinstance(result, Generator)

    for chunk in result:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
        if chunk.delta.finish_reason is not None:
            assert chunk.delta.usage is not None
            assert chunk.delta.usage.completion_tokens > 0


def test_get_num_tokens():
    model = FireworksLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert num_tokens == 10

    num_tokens = model.get_num_tokens(
        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        tools=[
            PromptMessageTool(
                name="get_weather",
                description="Determine weather in my location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            ),
        ],
    )

    assert num_tokens == 77
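A note on the fixtures: every network-facing test above is parametrized with setup_openai_mock, which patches the OpenAI client; this reflects the provider being driven through an OpenAI-compatible interface, and it lets the suite pass without live Fireworks access. The exact counts asserted in test_get_num_tokens (10 and 77) appear to come from local token counting rather than a remote call, which is why they can be asserted exactly.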
@ -0,0 +1,17 @@
import os

import pytest

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.fireworks.fireworks import FireworksProvider
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_provider_credentials(setup_openai_mock):
    provider = FireworksProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={})

    provider.validate_provider_credentials(credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")})
@ -6,5 +6,5 @@ pytest api/tests/integration_tests/model_runtime/anthropic \
 api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm \
 api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference \
 api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py \
-api/tests/integration_tests/model_runtime/upstage
+api/tests/integration_tests/model_runtime/upstage \
+api/tests/integration_tests/model_runtime/fireworks
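With a real FIREWORKS_API_KEY exported (in place of the pytest-env placeholder above), the new suite can also be run on its own with pytest api/tests/integration_tests/model_runtime/fireworks.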