Mirror of https://github.com/langgenius/dify.git
Synced 2024-11-16 11:42:29 +08:00

Compare commits: 3f10f1f84b ... a3562af5fc
33 Commits
a3562af5fc
03cf2c5ce9
5ff02b469f
432272c9c6
8c1a20ed5e
39d0da72a0
44f57ad9a8
94fd6f6901
e61242a337
e4be3a4ec7
79dfe6f5b1
81a32cede0
9ef0c2b582
1bb685d19e
b703e2528c
52ae57dc4e
722964667f
fbb9c1c249
15f341b655
b358490607
f9e4196fd5
751525802d
2abacd2a2d
a3155e0613
fb23e759c7
7fe986df8a
70b9e4caf5
93b6164a19
317ae9233e
5b8f03cd9d
2a4783307a
bddecba9ed
931e76e3d1
.github/pull_request_template.md (vendored, 54 lines changed)

@@ -1,34 +1,32 @@
-# Checklist:
+# Summary
+
+Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
+
+> [!Tip]
+> Close issue syntax: `Fixes #<issue number>` or `Resolves #<issue number>`, see [documentation](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword) for more details.
+
+# Screenshots
+
+<table>
+    <tr>
+        <td>Before: </td>
+        <td>After: </td>
+    </tr>
+    <tr>
+        <td>...</td>
+        <td>...</td>
+    </tr>
+</table>
+
+# Checklist
+
+> [!IMPORTANT]
+> Please review the checklist below before submitting your pull request.
+
+- [ ] Please open an issue before creating a PR or link to an existing issue
+- [ ] I have performed a self-review of my own code
+- [ ] I have commented my code, particularly in hard-to-understand areas
+- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
-
-# Description
-
-Describe the big picture of your changes here to communicate to the maintainers why we should accept this pull request. If it fixes a bug or resolves a feature request, be sure to link to that issue. Close issue syntax: `Fixes #<issue number>`, see [documentation](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword) for more details.
-
-Fixes
-
-## Type of Change
-
-- [ ] Bug fix (non-breaking change which fixes an issue)
-- [ ] New feature (non-breaking change which adds functionality)
-- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
-- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
-- [ ] Improvement, including but not limited to code refactoring, performance optimization, and UI/UX improvement
-- [ ] Dependency upgrade
-
-# Testing Instructions
-
-Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration
-
-- [ ] Test A
-- [ ] Test B
-
-- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
-- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
-- [x] I've updated the documentation accordingly.
-- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods

@@ -55,7 +55,7 @@ RUN apt-get update \
     && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
     && apt-get update \
     # For Security
-    && apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-7 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
+    && apt-get install -y --no-install-recommends expat=2.6.4-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-7 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
     # install a chinese font to support the use of tools like matplotlib
     && apt-get install -y fonts-noto-cjk \
     && apt-get autoremove -y \

@@ -217,6 +217,7 @@ class WorkflowCycleManage:
         ).total_seconds()
+        db.session.commit()

         db.session.add(workflow_run)
         db.session.refresh(workflow_run)
         db.session.close()

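A hedged reading of the one-line addition above, written as inline comments (the surrounding calls are from the diff; the rationale is inferred, not stated in the commit):

    db.session.commit()               # persist the pending elapsed_time update first
    db.session.add(workflow_run)      # re-attach the instance to the session if it was detached
    db.session.refresh(workflow_run)  # reload attributes from the committed row
    db.session.close()                # without the commit, close() would roll the change back
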
@@ -74,6 +74,8 @@ def to_prompt_message_content(
                 data = _to_url(f)
             else:
                 data = _to_base64_data_string(f)
+            if f.extension is None:
+                raise ValueError("Missing file extension")
             return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
         case _:
             raise ValueError("file type f.type is not supported")

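The two added lines make the `f.extension.lstrip(".")` call on the following line safe: a file with no extension would otherwise fail with an `AttributeError` on `None` rather than this explicit `ValueError`.
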
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from typing import Optional

 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager

@@ -27,7 +28,7 @@ class TokenBufferMemory:

     def get_history_prompt_messages(
         self, max_token_limit: int = 2000, message_limit: Optional[int] = None
-    ) -> list[PromptMessage]:
+    ) -> Sequence[PromptMessage]:
         """
         Get history prompt messages.
         :param max_token_limit: max token limit

@@ -100,7 +100,7 @@ class ModelInstance:

     def invoke_llm(
         self,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Sequence[PromptMessageTool] | None = None,
         stop: Optional[Sequence[str]] = None,

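The `list[PromptMessage]` to `Sequence[PromptMessage]` widening in the two hunks above matters to type checkers: `list` is invariant, while `Sequence` is covariant and read-only, so callers holding a list of a `PromptMessage` subclass pass cleanly. A minimal sketch, assuming `UserPromptMessage` subclasses `PromptMessage` as in the runtime entities:

    from collections.abc import Sequence

    class PromptMessage: ...
    class UserPromptMessage(PromptMessage): ...

    def invoke_llm(prompt_messages: Sequence[PromptMessage]) -> None:
        ...  # read-only access is all the signature promises

    messages: list[UserPromptMessage] = [UserPromptMessage()]
    invoke_llm(messages)  # accepted; a list[PromptMessage] parameter would reject this
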
@@ -58,6 +58,7 @@ class PromptMessageContentType(Enum):
     IMAGE = "image"
     AUDIO = "audio"
     VIDEO = "video"
+    DOCUMENT = "document"


 class PromptMessageContent(BaseModel):

@@ -87,6 +87,9 @@ class ModelFeature(Enum):
     AGENT_THOUGHT = "agent-thought"
     VISION = "vision"
     STREAM_TOOL_CALL = "stream-tool-call"
+    DOCUMENT = "document"
+    VIDEO = "video"
+    AUDIO = "audio"


 class DefaultParameterName(str, Enum):

@@ -0,0 +1,46 @@ (new file)
model: abab7-chat-preview
label:
  en_US: Abab7-chat-preview
model_type: llm
features:
  - agent-thought
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 245760
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0.01
    max: 1
    default: 0.1
  - name: top_p
    use_template: top_p
    min: 0.01
    max: 1
    default: 0.95
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 2048
    min: 1
    max: 245760
  - name: mask_sensitive_info
    type: boolean
    default: true
    label:
      zh_Hans: 隐私保护
      en_US: Moderate
    help:
      zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
      en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/id..
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
pricing:
  input: '0.1'
  output: '0.1'
  unit: '0.001'
  currency: RMB

@@ -34,6 +34,7 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage


 class MinimaxLargeLanguageModel(LargeLanguageModel):
     model_apis = {
+        "abab7-chat-preview": MinimaxChatCompletionPro,
         "abab6.5s-chat": MinimaxChatCompletionPro,
         "abab6.5-chat": MinimaxChatCompletionPro,
         "abab6-chat": MinimaxChatCompletionPro,

@@ -8,6 +8,7 @@ features:
   - agent-thought
   - stream-tool-call
   - vision
+  - audio
 model_properties:
   mode: chat
   context_size: 128000

@@ -0,0 +1,84 @@ (new file)
model: OpenGVLab/InternVL2-26B
label:
  en_US: OpenGVLab/InternVL2-26B
model_type: llm
features:
  - vision
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 2000
    min: 1
    max: 2000
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: seed
    required: false
    type: int
    default: 1234
    label:
      zh_Hans: 随机种子
      en_US: Random seed
    help:
      zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
      en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      zh_Hans: 重复惩罚
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: Response Format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '21'
  output: '21'
  unit: '0.000001'
  currency: RMB

@@ -0,0 +1,84 @@ (new file)
model: Pro/OpenGVLab/InternVL2-8B
label:
  en_US: Pro/OpenGVLab/InternVL2-8B
model_type: llm
features:
  - vision
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 2000
    min: 1
    max: 2000
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: seed
    required: false
    type: int
    default: 1234
    label:
      zh_Hans: 随机种子
      en_US: Random seed
    help:
      zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
      en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      zh_Hans: 重复惩罚
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: Response Format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '21'
  output: '21'
  unit: '0.000001'
  currency: RMB

@@ -1,16 +1,18 @@
- Tencent/Hunyuan-A52B-Instruct
- Qwen/Qwen2.5-72B-Instruct
- Qwen/Qwen2.5-32B-Instruct
- Qwen/Qwen2.5-14B-Instruct
- Qwen/Qwen2.5-7B-Instruct
- Qwen/Qwen2.5-Coder-32B-Instruct
- Qwen/Qwen2.5-Coder-7B-Instruct
- Qwen/Qwen2.5-Math-72B-Instruct
- Qwen/Qwen2-72B-Instruct
- Qwen/Qwen2-57B-A14B-Instruct
- Qwen/Qwen2-7B-Instruct
- Qwen/Qwen2-VL-72B-Instruct
- Qwen/Qwen2-1.5B-Instruct
- Pro/Qwen/Qwen2-VL-7B-Instruct
- OpenGVLab/InternVL2-Llama3-76B
- OpenGVLab/InternVL2-26B
- Pro/OpenGVLab/InternVL2-8B
- deepseek-ai/DeepSeek-V2.5
- deepseek-ai/DeepSeek-V2-Chat
- deepseek-ai/DeepSeek-Coder-V2-Instruct
- THUDM/glm-4-9b-chat
- 01-ai/Yi-1.5-34B-Chat-16K
- 01-ai/Yi-1.5-9B-Chat-16K

@@ -20,9 +22,6 @@
- meta-llama/Meta-Llama-3.1-405B-Instruct
- meta-llama/Meta-Llama-3.1-70B-Instruct
- meta-llama/Meta-Llama-3.1-8B-Instruct
- meta-llama/Meta-Llama-3-70B-Instruct
- meta-llama/Meta-Llama-3-8B-Instruct
- google/gemma-2-27b-it
- google/gemma-2-9b-it
- mistralai/Mistral-7B-Instruct-v0.2
- mistralai/Mixtral-8x7B-Instruct-v0.1
- deepseek-ai/DeepSeek-V2-Chat

@@ -37,3 +37,4 @@ pricing:
   output: '1.33'
   unit: '0.000001'
   currency: RMB
+deprecated: true

@@ -37,3 +37,4 @@ pricing:
   output: '1.33'
   unit: '0.000001'
   currency: RMB
+deprecated: true

@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768

@@ -32,6 +34,18 @@ parameter_rules:
     required: false
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '1.33'
   output: '1.33'

@@ -32,6 +32,18 @@ parameter_rules:
     required: false
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '1.26'
   output: '1.26'

@@ -32,6 +32,18 @@ parameter_rules:
     required: false
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '0'
   output: '0'

@@ -32,6 +32,18 @@ parameter_rules:
     required: false
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '0'
   output: '0'

@@ -0,0 +1,84 @@ (new file)
model: Tencent/Hunyuan-A52B-Instruct
label:
  en_US: Tencent/Hunyuan-A52B-Instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 2000
    min: 1
    max: 2000
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: seed
    required: false
    type: int
    default: 1234
    label:
      zh_Hans: 随机种子
      en_US: Random seed
    help:
      zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
      en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      zh_Hans: 重复惩罚
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: Response Format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '21'
  output: '21'
  unit: '0.000001'
  currency: RMB

@@ -32,6 +32,18 @@ parameter_rules:
     required: false
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '1'
   output: '1'

@@ -32,6 +32,18 @@ parameter_rules:
     required: false
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '0'
   output: '0'

@@ -0,0 +1,84 @@ (new file)
model: OpenGVLab/InternVL2-Llama3-76B
label:
  en_US: OpenGVLab/InternVL2-Llama3-76B
model_type: llm
features:
  - vision
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 2000
    min: 1
    max: 2000
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: seed
    required: false
    type: int
    default: 1234
    label:
      zh_Hans: 随机种子
      en_US: Random seed
    help:
      zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
      en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      zh_Hans: 重复惩罚
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: Response Format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '21'
  output: '21'
  unit: '0.000001'
  currency: RMB

@@ -29,6 +29,9 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
         self._add_custom_parameters(credentials)
+        # {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
+        if "response_format" in model_parameters:
+            model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
         return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

     def validate_credentials(self, model: str, credentials: dict) -> None:

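A minimal sketch of what the three added lines do, pulled out as a standalone function (the function name is hypothetical; the mapping itself is taken from the hunk):

    def normalize_response_format(model_parameters: dict) -> dict:
        # the parameter arrives as a bare string, e.g. {"response_format": "json_object"},
        # but OpenAI-compatible endpoints expect {"response_format": {"type": "json_object"}}
        if "response_format" in model_parameters:
            model_parameters["response_format"] = {"type": model_parameters["response_format"]}
        return model_parameters

    assert normalize_response_format({"response_format": "json_object"}) == {
        "response_format": {"type": "json_object"}
    }
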
@@ -37,3 +37,4 @@ pricing:
   output: '4.13'
   unit: '0.000001'
   currency: RMB
+deprecated: true

@@ -37,3 +37,4 @@ pricing:
   output: '0'
   unit: '0.000001'
   currency: RMB
+deprecated: true

@@ -32,6 +32,18 @@ parameter_rules:
     required: false
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '21'
   output: '21'

@@ -6,7 +6,7 @@ features:
   - agent-thought
 model_properties:
   mode: chat
-  context_size: 32768
+  context_size: 8192
 parameter_rules:
   - name: temperature
     use_template: temperature

@@ -32,6 +32,18 @@ parameter_rules:
     required: false
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '4.13'
   output: '4.13'

@@ -32,6 +32,18 @@ parameter_rules:
     required: false
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '0'
   output: '0'

@@ -37,3 +37,4 @@ pricing:
   output: '1.26'
   unit: '0.000001'
   currency: RMB
+deprecated: true

@@ -37,3 +37,4 @@ pricing:
   output: '4.13'
   unit: '0.000001'
   currency: RMB
+deprecated: true

@@ -37,3 +37,4 @@ pricing:
   output: '0'
   unit: '0.000001'
   currency: RMB
+deprecated: true

@@ -0,0 +1,84 @@ (new file)
model: Qwen/Qwen2-VL-72B-Instruct
label:
  en_US: Qwen/Qwen2-VL-72B-Instruct
model_type: llm
features:
  - vision
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 2000
    min: 1
    max: 2000
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: seed
    required: false
    type: int
    default: 1234
    label:
      zh_Hans: 随机种子
      en_US: Random seed
    help:
      zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
      en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      zh_Hans: 重复惩罚
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: Response Format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '21'
  output: '21'
  unit: '0.000001'
  currency: RMB

@@ -0,0 +1,84 @@ (new file)
model: Pro/Qwen/Qwen2-VL-7B-Instruct
label:
  en_US: Pro/Qwen/Qwen2-VL-7B-Instruct
model_type: llm
features:
  - vision
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 2000
    min: 1
    max: 2000
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: seed
    required: false
    type: int
    default: 1234
    label:
      zh_Hans: 随机种子
      en_US: Random seed
    help:
      zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
      en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      zh_Hans: 重复惩罚
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: Response Format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '21'
  output: '21'
  unit: '0.000001'
  currency: RMB

@@ -32,6 +32,18 @@ parameter_rules:
     required: false
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '0.7'
   output: '0.7'

@@ -32,6 +32,18 @@ parameter_rules:
     required: false
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '1.26'
   output: '1.26'

@@ -32,6 +32,18 @@ parameter_rules:
     required: false
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '4.13'
   output: '4.13'

@@ -32,6 +32,18 @@ parameter_rules:
     required: false
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '0'
   output: '0'

@@ -0,0 +1,84 @@ (new file)
model: Qwen/Qwen2.5-Coder-32B-Instruct
label:
  en_US: Qwen/Qwen2.5-Coder-32B-Instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 8192
    min: 1
    max: 8192
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: seed
    required: false
    type: int
    default: 1234
    label:
      zh_Hans: 随机种子
      en_US: Random seed
    help:
      zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
      en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      zh_Hans: 重复惩罚
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: Response Format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '1.26'
  output: '1.26'
  unit: '0.000001'
  currency: RMB

@@ -66,7 +66,17 @@ parameter_rules:
       zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
       en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
   - name: response_format
     use_template: response_format
     label:
       zh_Hans: 回复格式
       en_US: Response Format
     type: string
     help:
       zh_Hans: 指定模型必须输出的格式
       en_US: specifying the format that the model must output
     required: false
     options:
       - text
       - json_object
 pricing:
   input: '0'
   output: '0'

@@ -66,7 +66,17 @@ parameter_rules:
       zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
       en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
   - name: response_format
     use_template: response_format
     label:
       zh_Hans: 回复格式
       en_US: Response Format
     type: string
     help:
       zh_Hans: 指定模型必须输出的格式
       en_US: specifying the format that the model must output
     required: false
     options:
       - text
       - json_object
 pricing:
   input: '4.13'
   output: '4.13'

@@ -0,0 +1,5 @@ (new file)
model: FunAudioLLM/SenseVoiceSmall
model_type: speech2text
model_properties:
  file_upload_limit: 1
  supported_file_extensions: mp3,wav

@@ -3,3 +3,4 @@ model_type: speech2text
 model_properties:
   file_upload_limit: 1
   supported_file_extensions: mp3,wav
+deprecated: true

@@ -78,3 +78,4 @@
 - regex
 - trello
 - vanna
+- fal

api/core/tools/provider/builtin/audio/_assets/icon.svg (new file, 3 lines)

@@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="200" height="200" viewBox="0 0 200 200" fill="none">
<path d="M167.358 102.395C167.358 117.174 157.246 129.18 144.61 131.027H137.861C125.225 129.18 115.113 117.174 115.113 102.395H100.792C100.792 123.637 115.118 142.106 133.653 145.801V164.276H147.139V145.801C165.674 142.106 180 124.558 180 102.4H167.358V102.395ZM154.717 62.677C154.717 53.4397 147.979 46.9765 140.396 46.9765C138.523 46.9446 136.663 47.3273 134.924 48.1024C133.185 48.8775 131.603 50.0294 130.27 51.4909C128.936 52.9524 127.878 54.6943 127.157 56.6148C126.436 58.5354 126.066 60.5962 126.07 62.677V78.3775H154.717V70.4478V62.677ZM126.07 102.395C126.07 111.632 132.813 118.095 140.396 118.095C142.269 118.127 144.13 117.744 145.868 116.969C147.607 116.194 149.189 115.042 150.523 113.581C151.856 112.119 152.914 110.377 153.635 108.457C154.356 106.536 154.726 104.475 154.722 102.395V86.694H126.07V102.395ZM92.1297 45.8938L70.4796 21.7595L69.4235 20.5865L59.604 20L68.3674 20.5865L67.3113 21.7654L64.1429 25.2961L63.6149 25.8826L64.1429 27.0614L66.2552 29.4133L77.8723 42.3631H54.1099C35.1 43.5361 20.3146 61.1896 20.3146 81.7874V83.5527H28.2354V81.7932C28.2354 65.8992 39.8525 52.3628 54.1099 51.1899H77.8723L66.2552 64.1338L64.671 65.8992L64.1429 67.0722L63.6149 67.6645L64.1429 68.251L68.3674 72.9606L68.8954 73.5471L69.4235 72.9606L74.1759 67.6645L92.1297 47.6591L92.6578 47.0727L92.1297 45.8938ZM20 95.8496V118.213H30.033V107.034H50.099V168.821H40.066V180H70.165V168.821H60.132V107.034H80.198V118.213H90.231V95.8496H20Z" fill="#FF0099"/>
</svg>

api/core/tools/provider/builtin/audio/audio.py (new file, 6 lines)

@@ -0,0 +1,6 @@
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class AudioToolProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        pass

api/core/tools/provider/builtin/audio/audio.yaml (new file, 11 lines)

@@ -0,0 +1,11 @@
identity:
  author: hjlarry
  name: audio
  label:
    en_US: Audio
  description:
    en_US: A tool for tts and asr.
    zh_Hans: 一个用于文本转语音和语音转文本的工具。
  icon: icon.svg
  tags:
    - utilities

api/core/tools/provider/builtin/audio/tools/asr.py (new file, 70 lines)

@@ -0,0 +1,70 @@
import io
from typing import Any

from core.file.enums import FileType
from core.file.file_manager import download
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from core.tools.tool.builtin_tool import BuiltinTool
from services.model_provider_service import ModelProviderService


class ASRTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
        file = tool_parameters.get("audio_file")
        if file.type != FileType.AUDIO:
            return [self.create_text_message("not a valid audio file")]
        audio_binary = io.BytesIO(download(file))
        audio_binary.name = "temp.mp3"
        provider, model = tool_parameters.get("model").split("#")
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.runtime.tenant_id,
            provider=provider,
            model_type=ModelType.SPEECH2TEXT,
            model=model,
        )
        text = model_instance.invoke_speech2text(
            file=audio_binary,
            user=user_id,
        )
        return [self.create_text_message(text)]

    def get_available_models(self) -> list[tuple[str, str]]:
        model_provider_service = ModelProviderService()
        models = model_provider_service.get_models_by_model_type(
            tenant_id=self.runtime.tenant_id, model_type="speech2text"
        )
        items = []
        for provider_model in models:
            provider = provider_model.provider
            for model in provider_model.models:
                items.append((provider, model.model))
        return items

    def get_runtime_parameters(self) -> list[ToolParameter]:
        parameters = []

        options = []
        for provider, model in self.get_available_models():
            option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
            options.append(option)

        parameters.append(
            ToolParameter(
                name="model",
                label=I18nObject(en_US="Model", zh_Hans="Model"),
                human_description=I18nObject(
                    en_US="All available ASR models",
                    zh_Hans="所有可用的 ASR 模型",
                ),
                type=ToolParameter.ToolParameterType.SELECT,
                form=ToolParameter.ToolParameterForm.FORM,
                required=True,
                default=options[0].value,
                options=options,
            )
        )
        return parameters

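Since `ToolParameterOption` values are flat strings, the tool packs both coordinates into one select value as `provider#model` and splits them back apart in `_invoke`. Note that `default=options[0].value` assumes at least one speech2text model is configured for the tenant.
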
api/core/tools/provider/builtin/audio/tools/asr.yaml (new file, 22 lines)

@@ -0,0 +1,22 @@
identity:
  name: asr
  author: hjlarry
  label:
    en_US: Speech To Text
description:
  human:
    en_US: Convert audio file to text.
    zh_Hans: 将音频文件转换为文本。
  llm: Convert audio file to text.
parameters:
  - name: audio_file
    type: file
    required: true
    label:
      en_US: Audio File
      zh_Hans: 音频文件
    human_description:
      en_US: The audio file to be converted.
      zh_Hans: 要转换的音频文件。
    llm_description: The audio file to be converted.
    form: llm

90  api/core/tools/provider/builtin/audio/tools/tts.py  Normal file

@@ -0,0 +1,90 @@
import io
from typing import Any

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from core.tools.tool.builtin_tool import BuiltinTool
from services.model_provider_service import ModelProviderService


class TTSTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
        provider, model = tool_parameters.get("model").split("#")
        voice = tool_parameters.get(f"voice#{provider}#{model}")
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.runtime.tenant_id,
            provider=provider,
            model_type=ModelType.TTS,
            model=model,
        )
        tts = model_instance.invoke_tts(
            content_text=tool_parameters.get("text"),
            user=user_id,
            tenant_id=self.runtime.tenant_id,
            voice=voice,
        )
        buffer = io.BytesIO()
        for chunk in tts:
            buffer.write(chunk)

        wav_bytes = buffer.getvalue()
        return [
            self.create_text_message("Audio generated successfully"),
            self.create_blob_message(
                blob=wav_bytes,
                meta={"mime_type": "audio/x-wav"},
                save_as=self.VariableKey.AUDIO,
            ),
        ]

    def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
        model_provider_service = ModelProviderService()
        models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
        items = []
        for provider_model in models:
            provider = provider_model.provider
            for model in provider_model.models:
                voices = model.model_properties.get(ModelPropertyKey.VOICES, [])
                items.append((provider, model.model, voices))
        return items

    def get_runtime_parameters(self) -> list[ToolParameter]:
        parameters = []

        options = []
        for provider, model, voices in self.get_available_models():
            option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
            options.append(option)
            parameters.append(
                ToolParameter(
                    name=f"voice#{provider}#{model}",
                    label=I18nObject(en_US=f"Voice of {model}({provider})"),
                    type=ToolParameter.ToolParameterType.SELECT,
                    form=ToolParameter.ToolParameterForm.FORM,
                    options=[
                        ToolParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
                        for voice in voices
                    ],
                )
            )

        parameters.insert(
            0,
            ToolParameter(
                name="model",
                label=I18nObject(en_US="Model", zh_Hans="Model"),
                human_description=I18nObject(
                    en_US="All available TTS models",
                    zh_Hans="所有可用的 TTS 模型",
                ),
                type=ToolParameter.ToolParameterType.SELECT,
                form=ToolParameter.ToolParameterForm.FORM,
                required=True,
                default=options[0].value,
                options=options,
            ),
        )
        return parameters
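Note: each TTS model gets its own voice selector, so the voice parameter name itself is derived from the selected model. A sketch of the lookup performed in _invoke, using illustrative placeholder values:

# tool_parameters stands in for the real form values; names are illustrative.
tool_parameters = {"model": "openai#tts-1", "voice#openai#tts-1": "alloy"}
provider, model = tool_parameters["model"].split("#")
voice = tool_parameters.get(f"voice#{provider}#{model}")
assert voice == "alloy"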
22  api/core/tools/provider/builtin/audio/tools/tts.yaml  Normal file

@@ -0,0 +1,22 @@
identity:
  name: tts
  author: hjlarry
  label:
    en_US: Text To Speech
description:
  human:
    en_US: Convert text to audio file.
    zh_Hans: 将文本转换为音频文件。
  llm: Convert text to audio file.
parameters:
  - name: text
    type: string
    required: true
    label:
      en_US: Text
      zh_Hans: 文本
    human_description:
      en_US: The text to be converted.
      zh_Hans: 要转换的文本。
    llm_description: The text to be converted.
    form: llm
4  api/core/tools/provider/builtin/fal/_assets/icon.svg  Normal file

@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg version="1.1" xmlns="http://www.w3.org/2000/svg" width="32" height="32">
<path d="M0 0 C3.96 0 7.92 0 12 0 C12.4125 0.928125 12.825 1.85625 13.25 2.8125 C15.56104487 7.02190315 17.49701732 8.49900577 22 10 C22 13.96 22 17.92 22 22 C21.071875 22.4125 20.14375 22.825 19.1875 23.25 C14.97809685 25.56104487 13.50099423 27.49701732 12 32 C8.04 32 4.08 32 0 32 C-0.4125 31.071875 -0.825 30.14375 -1.25 29.1875 C-3.56104487 24.97809685 -5.49701732 23.50099423 -10 22 C-10 18.04 -10 14.08 -10 10 C-9.071875 9.5875 -8.14375 9.175 -7.1875 8.75 C-2.97809685 6.43895513 -1.50099423 4.50298268 0 0 Z M-2 11 C-3.42662219 13.85324437 -3.31033868 15.83454549 -3 19 C-1.20006226 21.69990662 0.083773 23.5418865 3 25 C7.1364408 25.56406011 8.76045933 25.14638597 12.375 22.9375 C15.26054626 20.20817124 15.26054626 20.20817124 15.6875 16.5625 C14.76325283 11.77321919 13.68514918 10.2147046 10 7 C4.54838272 6.02649691 1.87056683 7.12943317 -2 11 Z " fill="#EC0648" transform="translate(10,0)"/>
</svg>
20  api/core/tools/provider/builtin/fal/fal.py  Normal file

@@ -0,0 +1,20 @@
import requests

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class FalProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        url = "https://fal.run/fal-ai/flux/dev"
        headers = {
            "Authorization": f"Key {credentials.get('fal_api_key')}",
            "Content-Type": "application/json",
        }
        data = {"prompt": "Cat"}

        response = requests.post(url, json=data, headers=headers)
        if response.status_code == 401:
            raise ToolProviderCredentialValidationError("FAL API key is invalid")
        elif response.status_code != 200:
            raise ToolProviderCredentialValidationError(f"FAL API key validation failed: {response.text}")
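Note: validation issues one real generation request against fal.run, so a bad key should fail loudly at configuration time rather than at first use. A hedged sketch of exercising it (assumes the imports from the file above; the key below is deliberately fake):

provider = FalProvider()
try:
    provider._validate_credentials({"fal_api_key": "not-a-real-key"})
except ToolProviderCredentialValidationError as e:
    print(e)  # expected: "FAL API key is invalid" on a 401 response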
21  api/core/tools/provider/builtin/fal/fal.yaml  Normal file

@@ -0,0 +1,21 @@
identity:
  author: Kalo Chin
  name: fal
  label:
    en_US: FAL
    zh_CN: FAL
  description:
    en_US: The image generation API provided by FAL.
    zh_CN: FAL 提供的图像生成 API。
  icon: icon.svg
  tags:
    - image
credentials_for_provider:
  fal_api_key:
    type: secret-input
    required: true
    label:
      en_US: FAL API Key
    placeholder:
      en_US: Please input your FAL API key
    url: https://fal.ai/dashboard/keys
46  api/core/tools/provider/builtin/fal/tools/flux_1_1_pro.py  Normal file

@@ -0,0 +1,46 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class Flux11ProTool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        headers = {
            "Authorization": f"Key {self.runtime.credentials['fal_api_key']}",
            "Content-Type": "application/json",
        }

        prompt = tool_parameters.get("prompt", "")
        sanitized_prompt = prompt.replace("\\", "")  # Remove backslashes from the prompt which may cause errors

        payload = {
            "prompt": sanitized_prompt,
            "image_size": tool_parameters.get("image_size", "landscape_4_3"),
            "seed": tool_parameters.get("seed"),
            "sync_mode": tool_parameters.get("sync_mode", False),
            "num_images": tool_parameters.get("num_images", 1),
            "enable_safety_checker": tool_parameters.get("enable_safety_checker", True),
            "safety_tolerance": tool_parameters.get("safety_tolerance", "2"),
        }

        url = "https://fal.run/fal-ai/flux-pro/v1.1"

        response = requests.post(url, json=payload, headers=headers)

        if response.status_code != 200:
            return self.create_text_message(f"Got Error Response: {response.text}")

        res = response.json()
        result = [self.create_json_message(res)]

        for image_info in res.get("images", []):
            image_url = image_info.get("url")
            if image_url:
                result.append(self.create_image_message(image=image_url, save_as=self.VariableKey.IMAGE.value))

        return result
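Note: only the "images" list of the JSON response is turned into image messages; the rest of the body is passed through as one JSON message. A sketch of the shape the tool expects, with a placeholder URL standing in for a real response:

# Illustrative response body; the real one comes from fal.run.
res = {"images": [{"url": "https://example.com/generated.png"}], "seed": 42}
for image_info in res.get("images", []):
    image_url = image_info.get("url")
    if image_url:
        print("would emit an image message for", image_url)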
147  api/core/tools/provider/builtin/fal/tools/flux_1_1_pro.yaml  Normal file

@@ -0,0 +1,147 @@
identity:
  name: flux_1_1_pro
  author: Kalo Chin
  label:
    en_US: FLUX 1.1 [pro]
    zh_Hans: FLUX 1.1 [pro]
  icon: icon.svg
description:
  human:
    en_US: FLUX 1.1 [pro] is an enhanced version of FLUX.1 [pro] with improved image generation capabilities, delivering superior composition, detail, and artistic fidelity compared to its predecessor.
    zh_Hans: FLUX 1.1 [pro] 是 FLUX.1 [pro] 的增强版,改进了图像生成能力,与其前身相比,提供了更出色的构图、细节和艺术保真度。
  llm: This tool generates images from prompts using FAL's FLUX 1.1 [pro] model.
parameters:
  - name: prompt
    type: string
    required: true
    label:
      en_US: Prompt
      zh_Hans: 提示词
    human_description:
      en_US: The text prompt used to generate the image.
      zh_Hans: 用于生成图片的文字提示词。
    llm_description: This prompt text will be used to generate the image.
    form: llm
  - name: image_size
    type: select
    required: false
    options:
      - value: square_hd
        label:
          en_US: Square HD
          zh_Hans: 方形高清
      - value: square
        label:
          en_US: Square
          zh_Hans: 方形
      - value: portrait_4_3
        label:
          en_US: Portrait 4:3
          zh_Hans: 竖屏 4:3
      - value: portrait_16_9
        label:
          en_US: Portrait 16:9
          zh_Hans: 竖屏 16:9
      - value: landscape_4_3
        label:
          en_US: Landscape 4:3
          zh_Hans: 横屏 4:3
      - value: landscape_16_9
        label:
          en_US: Landscape 16:9
          zh_Hans: 横屏 16:9
    default: landscape_4_3
    label:
      en_US: Image Size
      zh_Hans: 图片大小
    human_description:
      en_US: The size of the generated image.
      zh_Hans: 生成图像的尺寸。
    form: form
  - name: num_images
    type: number
    required: false
    default: 1
    min: 1
    max: 1
    label:
      en_US: Number of Images
      zh_Hans: 图片数量
    human_description:
      en_US: The number of images to generate.
      zh_Hans: 要生成的图片数量。
    form: form
  - name: safety_tolerance
    type: select
    required: false
    options:
      - value: "1"
        label:
          en_US: "1 (Most strict)"
          zh_Hans: "1(最严格)"
      - value: "2"
        label:
          en_US: "2"
          zh_Hans: "2"
      - value: "3"
        label:
          en_US: "3"
          zh_Hans: "3"
      - value: "4"
        label:
          en_US: "4"
          zh_Hans: "4"
      - value: "5"
        label:
          en_US: "5"
          zh_Hans: "5"
      - value: "6"
        label:
          en_US: "6 (Most permissive)"
          zh_Hans: "6(最宽松)"
    default: "2"
    label:
      en_US: Safety Tolerance
      zh_Hans: 安全容忍度
    human_description:
      en_US: The safety tolerance level for the generated image. 1 being the most strict and 6 being the most permissive.
      zh_Hans: 生成图像的安全容忍级别,1 为最严格,6 为最宽松。
    form: form
  - name: seed
    type: number
    required: false
    min: 0
    max: 9999999999
    label:
      en_US: Seed
      zh_Hans: 种子
    human_description:
      en_US: The same seed and prompt can produce similar images.
      zh_Hans: 相同的种子和提示词可以产生相似的图像。
    form: form
  - name: enable_safety_checker
    type: boolean
    required: false
    default: true
    label:
      en_US: Enable Safety Checker
      zh_Hans: 启用安全检查器
    human_description:
      en_US: Enable or disable the safety checker.
      zh_Hans: 启用或禁用安全检查器。
    form: form
  - name: sync_mode
    type: boolean
    required: false
    default: false
    label:
      en_US: Sync Mode
      zh_Hans: 同步模式
    human_description:
      en_US: >
        If set to true, the function will wait for the image to be generated and uploaded before returning the response.
        This will increase the latency but allows you to get the image directly in the response without going through the CDN.
      zh_Hans: >
        如果设置为 true,函数将在生成并上传图像后再返回响应。
        这将增加函数的延迟,但可以让您直接在响应中获取图像,而无需通过 CDN。
    form: form
47  api/core/tools/provider/builtin/fal/tools/flux_1_1_pro_ultra.py  Normal file

@@ -0,0 +1,47 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class Flux11ProUltraTool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        headers = {
            "Authorization": f"Key {self.runtime.credentials['fal_api_key']}",
            "Content-Type": "application/json",
        }

        prompt = tool_parameters.get("prompt", "")
        sanitized_prompt = prompt.replace("\\", "")  # Remove backslashes from the prompt which may cause errors

        payload = {
            "prompt": sanitized_prompt,
            "seed": tool_parameters.get("seed"),
            "sync_mode": tool_parameters.get("sync_mode", False),
            "num_images": tool_parameters.get("num_images", 1),
            "enable_safety_checker": tool_parameters.get("enable_safety_checker", True),
            "safety_tolerance": str(tool_parameters.get("safety_tolerance", "2")),
            "aspect_ratio": tool_parameters.get("aspect_ratio", "16:9"),
            "raw": tool_parameters.get("raw", False),
        }

        url = "https://fal.run/fal-ai/flux-pro/v1.1-ultra"

        response = requests.post(url, json=payload, headers=headers)

        if response.status_code != 200:
            return self.create_text_message(f"Got Error Response: {response.text}")

        res = response.json()
        result = [self.create_json_message(res)]

        for image_info in res.get("images", []):
            image_url = image_info.get("url")
            if image_url:
                result.append(self.create_image_message(image=image_url, save_as=self.VariableKey.IMAGE.value))

        return result
162  api/core/tools/provider/builtin/fal/tools/flux_1_1_pro_ultra.yaml  Normal file

@@ -0,0 +1,162 @@
identity:
  name: flux_1_1_pro_ultra
  author: Kalo Chin
  label:
    en_US: FLUX 1.1 [pro] ultra
    zh_Hans: FLUX 1.1 [pro] ultra
  icon: icon.svg
description:
  human:
    en_US: FLUX 1.1 [pro] ultra is the newest version of FLUX 1.1 [pro], maintaining professional-grade image quality while delivering up to 2K resolution with improved photo realism.
    zh_Hans: FLUX 1.1 [pro] ultra 是 FLUX 1.1 [pro] 的最新版本,保持了专业级的图像质量,同时以改进的照片真实感提供高达 2K 的分辨率。
  llm: This tool generates images from prompts using FAL's FLUX 1.1 [pro] ultra model.
parameters:
  - name: prompt
    type: string
    required: true
    label:
      en_US: Prompt
      zh_Hans: 提示词
    human_description:
      en_US: The text prompt used to generate the image.
      zh_Hans: 用于生成图像的文本提示。
    llm_description: This prompt text will be used to generate the image.
    form: llm
  - name: aspect_ratio
    type: select
    required: false
    options:
      - value: '21:9'
        label:
          en_US: '21:9'
          zh_Hans: '21:9'
      - value: '16:9'
        label:
          en_US: '16:9'
          zh_Hans: '16:9'
      - value: '4:3'
        label:
          en_US: '4:3'
          zh_Hans: '4:3'
      - value: '1:1'
        label:
          en_US: '1:1'
          zh_Hans: '1:1'
      - value: '3:4'
        label:
          en_US: '3:4'
          zh_Hans: '3:4'
      - value: '9:16'
        label:
          en_US: '9:16'
          zh_Hans: '9:16'
      - value: '9:21'
        label:
          en_US: '9:21'
          zh_Hans: '9:21'
    default: '16:9'
    label:
      en_US: Aspect Ratio
      zh_Hans: 纵横比
    human_description:
      en_US: The aspect ratio of the generated image.
      zh_Hans: 生成图像的宽高比。
    form: form
  - name: num_images
    type: number
    required: false
    default: 1
    min: 1
    max: 1
    label:
      en_US: Number of Images
      zh_Hans: 图片数量
    human_description:
      en_US: The number of images to generate.
      zh_Hans: 要生成的图像数量。
    form: form
  - name: safety_tolerance
    type: select
    required: false
    options:
      - value: "1"
        label:
          en_US: "1 (Most strict)"
          zh_Hans: "1(最严格)"
      - value: "2"
        label:
          en_US: "2"
          zh_Hans: "2"
      - value: "3"
        label:
          en_US: "3"
          zh_Hans: "3"
      - value: "4"
        label:
          en_US: "4"
          zh_Hans: "4"
      - value: "5"
        label:
          en_US: "5"
          zh_Hans: "5"
      - value: "6"
        label:
          en_US: "6 (Most permissive)"
          zh_Hans: "6(最宽松)"
    default: "2"
    label:
      en_US: Safety Tolerance
      zh_Hans: 安全容忍度
    human_description:
      en_US: The safety tolerance level for the generated image. 1 being the most strict and 6 being the most permissive.
      zh_Hans: 生成图像的安全容忍级别,1 为最严格,6 为最宽松。
    form: form
  - name: seed
    type: number
    required: false
    min: 0
    max: 9999999999
    label:
      en_US: Seed
      zh_Hans: 种子
    human_description:
      en_US: The same seed and prompt can produce similar images.
      zh_Hans: 相同的种子和提示词可以生成相似的图像。
    form: form
  - name: raw
    type: boolean
    required: false
    default: false
    label:
      en_US: Raw Mode
      zh_Hans: 原始模式
    human_description:
      en_US: Generate less processed, more natural-looking images.
      zh_Hans: 生成较少处理、更自然的图像。
    form: form
  - name: enable_safety_checker
    type: boolean
    required: false
    default: true
    label:
      en_US: Enable Safety Checker
      zh_Hans: 启用安全检查器
    human_description:
      en_US: Enable or disable the safety checker.
      zh_Hans: 启用或禁用安全检查器。
    form: form
  - name: sync_mode
    type: boolean
    required: false
    default: false
    label:
      en_US: Sync Mode
      zh_Hans: 同步模式
    human_description:
      en_US: >
        If set to true, the function will wait for the image to be generated and uploaded before returning the response.
        This will increase the latency but allows you to get the image directly in the response without going through the CDN.
      zh_Hans: >
        如果设置为 true,函数将在生成并上传图像后才返回响应。
        这将增加延迟,但允许您直接在响应中获取图像,而无需通过 CDN。
    form: form
47  api/core/tools/provider/builtin/fal/tools/flux_1_dev.py  Normal file

@@ -0,0 +1,47 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class Flux1DevTool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        headers = {
            "Authorization": f"Key {self.runtime.credentials['fal_api_key']}",
            "Content-Type": "application/json",
        }

        prompt = tool_parameters.get("prompt", "")
        sanitized_prompt = prompt.replace("\\", "")  # Remove backslashes from the prompt which may cause errors

        payload = {
            "prompt": sanitized_prompt,
            "image_size": tool_parameters.get("image_size", "landscape_4_3"),
            "num_inference_steps": tool_parameters.get("num_inference_steps", 28),
            "guidance_scale": tool_parameters.get("guidance_scale", 3.5),
            "seed": tool_parameters.get("seed"),
            "num_images": tool_parameters.get("num_images", 1),
            "enable_safety_checker": tool_parameters.get("enable_safety_checker", True),
            "sync_mode": tool_parameters.get("sync_mode", False),
        }

        url = "https://fal.run/fal-ai/flux/dev"

        response = requests.post(url, json=payload, headers=headers)

        if response.status_code != 200:
            return self.create_text_message(f"Got Error Response: {response.text}")

        res = response.json()
        result = [self.create_json_message(res)]

        for image_info in res.get("images", []):
            image_url = image_info.get("url")
            if image_url:
                result.append(self.create_image_message(image=image_url, save_as=self.VariableKey.IMAGE.value))

        return result
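Note: unlike the pro variants, the dev endpoint exposes sampler knobs; when the form leaves them empty the tool falls back to fixed defaults. A sketch of that fallback (values illustrative):

tool_parameters = {"prompt": "a lighthouse at dusk"}
payload = {
    "num_inference_steps": tool_parameters.get("num_inference_steps", 28),
    "guidance_scale": tool_parameters.get("guidance_scale", 3.5),
}
assert payload == {"num_inference_steps": 28, "guidance_scale": 3.5}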
137  api/core/tools/provider/builtin/fal/tools/flux_1_dev.yaml  Normal file

@@ -0,0 +1,137 @@
identity:
  name: flux_1_dev
  author: Kalo Chin
  label:
    en_US: FLUX.1 [dev]
    zh_Hans: FLUX.1 [dev]
  icon: icon.svg
description:
  human:
    en_US: FLUX.1 [dev] is a 12 billion parameter flow transformer that generates high-quality images from text. It is suitable for personal and commercial use.
    zh_Hans: FLUX.1 [dev] 是一个拥有120亿参数的流动变换模型,可以从文本生成高质量的图像。适用于个人和商业用途。
  llm: This tool generates images from prompts using FAL's FLUX.1 [dev] model.
parameters:
  - name: prompt
    type: string
    required: true
    label:
      en_US: Prompt
      zh_Hans: 提示词
    human_description:
      en_US: The text prompt used to generate the image.
      zh_Hans: 用于生成图片的文字提示词。
    llm_description: This prompt text will be used to generate the image.
    form: llm
  - name: image_size
    type: select
    required: false
    options:
      - value: square_hd
        label:
          en_US: Square HD
          zh_Hans: 方形高清
      - value: square
        label:
          en_US: Square
          zh_Hans: 方形
      - value: portrait_4_3
        label:
          en_US: Portrait 4:3
          zh_Hans: 竖屏 4:3
      - value: portrait_16_9
        label:
          en_US: Portrait 16:9
          zh_Hans: 竖屏 16:9
      - value: landscape_4_3
        label:
          en_US: Landscape 4:3
          zh_Hans: 横屏 4:3
      - value: landscape_16_9
        label:
          en_US: Landscape 16:9
          zh_Hans: 横屏 16:9
    default: landscape_4_3
    label:
      en_US: Image Size
      zh_Hans: 图片大小
    human_description:
      en_US: The size of the generated image.
      zh_Hans: 生成图像的尺寸。
    form: form
  - name: num_images
    type: number
    required: false
    default: 1
    min: 1
    max: 4
    label:
      en_US: Number of Images
      zh_Hans: 图片数量
    human_description:
      en_US: The number of images to generate.
      zh_Hans: 要生成的图片数量。
    form: form
  - name: num_inference_steps
    type: number
    required: false
    default: 28
    min: 1
    max: 50
    label:
      en_US: Num Inference Steps
      zh_Hans: 推理步数
    human_description:
      en_US: The number of inference steps to perform. More steps produce higher quality but take longer.
      zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。
    form: form
  - name: guidance_scale
    type: number
    required: false
    default: 3.5
    min: 0
    max: 20
    label:
      en_US: Guidance Scale
      zh_Hans: 指导强度
    human_description:
      en_US: How closely the model should follow the prompt.
      zh_Hans: 模型对提示词的遵循程度。
    form: form
  - name: seed
    type: number
    required: false
    min: 0
    max: 9999999999
    label:
      en_US: Seed
      zh_Hans: 种子
    human_description:
      en_US: The same seed and prompt can produce similar images.
      zh_Hans: 相同的种子和提示可以产生相似的图像。
    form: form
  - name: enable_safety_checker
    type: boolean
    required: false
    default: true
    label:
      en_US: Enable Safety Checker
      zh_Hans: 启用安全检查器
    human_description:
      en_US: Enable or disable the safety checker.
      zh_Hans: 启用或禁用安全检查器。
    form: form
  - name: sync_mode
    type: boolean
    required: false
    default: false
    label:
      en_US: Sync Mode
      zh_Hans: 同步模式
    human_description:
      en_US: >
        If set to true, the function will wait for the image to be generated and uploaded before returning the response.
        This will increase the latency but allows you to get the image directly in the response without going through the CDN.
      zh_Hans: >
        如果设置为 true,函数将在生成并上传图像后再返回响应。
        这将增加函数的延迟,但可以让您直接在响应中获取图像,而无需通过 CDN。
    form: form
47  api/core/tools/provider/builtin/fal/tools/flux_1_pro_new.py  Normal file

@@ -0,0 +1,47 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class Flux1ProNewTool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        headers = {
            "Authorization": f"Key {self.runtime.credentials['fal_api_key']}",
            "Content-Type": "application/json",
        }

        prompt = tool_parameters.get("prompt", "")
        sanitized_prompt = prompt.replace("\\", "")  # Remove backslashes that may cause errors

        payload = {
            "prompt": sanitized_prompt,
            "image_size": tool_parameters.get("image_size", "landscape_4_3"),
            "num_inference_steps": tool_parameters.get("num_inference_steps", 28),
            "guidance_scale": tool_parameters.get("guidance_scale", 3.5),
            "seed": tool_parameters.get("seed"),
            "num_images": tool_parameters.get("num_images", 1),
            "safety_tolerance": tool_parameters.get("safety_tolerance", "2"),
            "sync_mode": tool_parameters.get("sync_mode", False),
        }

        url = "https://fal.run/fal-ai/flux-pro/new"

        response = requests.post(url, json=payload, headers=headers)

        if response.status_code != 200:
            return self.create_text_message(f"Got Error Response: {response.text}")

        res = response.json()
        result = [self.create_json_message(res)]

        for image_info in res.get("images", []):
            image_url = image_info.get("url")
            if image_url:
                result.append(self.create_image_message(image=image_url, save_as=self.VariableKey.IMAGE.value))

        return result
164  api/core/tools/provider/builtin/fal/tools/flux_1_pro_new.yaml  Normal file

@@ -0,0 +1,164 @@
identity:
  name: flux_1_pro_new
  author: Kalo Chin
  label:
    en_US: FLUX.1 [pro] new
    zh_Hans: FLUX.1 [pro] new
  icon: icon.svg
description:
  human:
    en_US: FLUX.1 [pro] new is an accelerated version of FLUX.1 [pro], maintaining professional-grade image quality while delivering significantly faster generation speeds.
    zh_Hans: FLUX.1 [pro] new 是 FLUX.1 [pro] 的加速版本,在保持专业级图像质量的同时,大大提高了生成速度。
  llm: This tool generates images from prompts using FAL's FLUX.1 [pro] new model.
parameters:
  - name: prompt
    type: string
    required: true
    label:
      en_US: Prompt
      zh_Hans: 提示词
    human_description:
      en_US: The text prompt used to generate the image.
      zh_Hans: 用于生成图像的文本提示。
    llm_description: This prompt text will be used to generate the image.
    form: llm
  - name: image_size
    type: select
    required: false
    options:
      - value: square_hd
        label:
          en_US: Square HD
          zh_Hans: 正方形高清
      - value: square
        label:
          en_US: Square
          zh_Hans: 正方形
      - value: portrait_4_3
        label:
          en_US: Portrait 4:3
          zh_Hans: 竖屏 4:3
      - value: portrait_16_9
        label:
          en_US: Portrait 16:9
          zh_Hans: 竖屏 16:9
      - value: landscape_4_3
        label:
          en_US: Landscape 4:3
          zh_Hans: 横屏 4:3
      - value: landscape_16_9
        label:
          en_US: Landscape 16:9
          zh_Hans: 横屏 16:9
    default: landscape_4_3
    label:
      en_US: Image Size
      zh_Hans: 图像尺寸
    human_description:
      en_US: The size of the generated image.
      zh_Hans: 生成图像的尺寸。
    form: form
  - name: num_images
    type: number
    required: false
    default: 1
    min: 1
    max: 1
    label:
      en_US: Number of Images
      zh_Hans: 图像数量
    human_description:
      en_US: The number of images to generate.
      zh_Hans: 要生成的图像数量。
    form: form
  - name: num_inference_steps
    type: number
    required: false
    default: 28
    min: 1
    max: 50
    label:
      en_US: Num Inference Steps
      zh_Hans: 推理步数
    human_description:
      en_US: The number of inference steps to perform. More steps produce higher quality but take longer.
      zh_Hans: 执行的推理步数。步数越多,质量越高,但所需时间也更长。
    form: form
  - name: guidance_scale
    type: number
    required: false
    default: 3.5
    min: 0
    max: 20
    label:
      en_US: Guidance Scale
      zh_Hans: 指导强度
    human_description:
      en_US: How closely the model should follow the prompt.
      zh_Hans: 模型对提示词的遵循程度。
    form: form
  - name: safety_tolerance
    type: select
    required: false
    options:
      - value: "1"
        label:
          en_US: "1 (Most strict)"
          zh_Hans: "1(最严格)"
      - value: "2"
        label:
          en_US: "2"
          zh_Hans: "2"
      - value: "3"
        label:
          en_US: "3"
          zh_Hans: "3"
      - value: "4"
        label:
          en_US: "4"
          zh_Hans: "4"
      - value: "5"
        label:
          en_US: "5"
          zh_Hans: "5"
      - value: "6"
        label:
          en_US: "6 (Most permissive)"
          zh_Hans: "6(最宽松)"
    default: "2"
    label:
      en_US: Safety Tolerance
      zh_Hans: 安全容忍度
    human_description:
      en_US: >
        The safety tolerance level for the generated image. 1 being the most strict and 6 being the most permissive.
      zh_Hans: >
        生成图像的安全容忍级别。1 是最严格,6 是最宽松。
    form: form
  - name: seed
    type: number
    required: false
    min: 0
    max: 9999999999
    label:
      en_US: Seed
      zh_Hans: 种子
    human_description:
      en_US: The same seed and prompt can produce similar images.
      zh_Hans: 相同的种子和提示词可以生成相似的图像。
    form: form
  - name: sync_mode
    type: boolean
    required: false
    default: false
    label:
      en_US: Sync Mode
      zh_Hans: 同步模式
    human_description:
      en_US: >
        If set to true, the function will wait for the image to be generated and uploaded before returning the response.
        This will increase the latency but allows you to get the image directly in the response without going through the CDN.
      zh_Hans: >
        如果设置为 true,函数将在生成并上传图像后才返回响应。
        这将增加延迟,但允许您直接在响应中获取图像,而无需通过 CDN。
    form: form
api/core/tools/provider/builtin/podcast_generator/podcast_generator.py

@@ -1,6 +1,7 @@
 from typing import Any

 import openai
+from yarl import URL

 from core.tools.errors import ToolProviderCredentialValidationError
 from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController

@@ -10,6 +11,7 @@ class PodcastGeneratorProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: dict[str, Any]) -> None:
         tts_service = credentials.get("tts_service")
         api_key = credentials.get("api_key")
+        base_url = credentials.get("openai_base_url")

         if not tts_service:
             raise ToolProviderCredentialValidationError("TTS service is not specified")

@@ -17,13 +19,16 @@ class PodcastGeneratorProvider(BuiltinToolProviderController):
         if not api_key:
             raise ToolProviderCredentialValidationError("API key is missing")

+        if base_url:
+            base_url = str(URL(base_url) / "v1")
+
         if tts_service == "openai":
-            self._validate_openai_credentials(api_key)
+            self._validate_openai_credentials(api_key, base_url)
         else:
             raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}")

-    def _validate_openai_credentials(self, api_key: str) -> None:
-        client = openai.OpenAI(api_key=api_key)
+    def _validate_openai_credentials(self, api_key: str, base_url: str | None) -> None:
+        client = openai.OpenAI(api_key=api_key, base_url=base_url)
         try:
             # We're using a simple API call to validate the credentials
             client.models.list()
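Note: the base_url handling leans on yarl's path-join operator, so a bare host gains the /v1 suffix before it reaches the OpenAI client. A quick check of the assumed behavior (hostname is a placeholder):

from yarl import URL

base_url = str(URL("https://api.example.com") / "v1")
assert base_url == "https://api.example.com/v1"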
api/core/variables/segments.py

@@ -118,11 +118,11 @@ class FileSegment(Segment):

     @property
     def log(self) -> str:
-        return str(self.value)
+        return ""

     @property
     def text(self) -> str:
-        return str(self.value)
+        return ""


 class ArrayAnySegment(ArraySegment):

@@ -155,3 +155,11 @@ class ArrayFileSegment(ArraySegment):
         for item in self.value:
             items.append(item.markdown)
         return "\n".join(items)
+
+    @property
+    def log(self) -> str:
+        return ""
+
+    @property
+    def text(self) -> str:
+        return ""
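Note: as far as the diff shows, the motivation is to keep raw file reprs out of run logs and out of text interpolation. A hypothetical illustration of what str(self.value) used to leak (FileStub is a stand-in, not a Dify class):

class FileStub:
    def __repr__(self) -> str:
        return "File(id=123, transfer_method=local_file, url=...)"

# Previously FileSegment.log/.text returned str(self.value), i.e. this repr:
print(str(FileStub()))
# After this change both properties return "" for file segments.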
api/core/workflow/nodes/document_extractor/node.py

@@ -143,14 +143,14 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)

 def _extract_text_from_plain_text(file_content: bytes) -> str:
     try:
-        return file_content.decode("utf-8")
+        return file_content.decode("utf-8", "ignore")
     except UnicodeDecodeError as e:
         raise TextExtractionError("Failed to decode plain text file") from e


 def _extract_text_from_json(file_content: bytes) -> str:
     try:
-        json_data = json.loads(file_content.decode("utf-8"))
+        json_data = json.loads(file_content.decode("utf-8", "ignore"))
         return json.dumps(json_data, indent=2, ensure_ascii=False)
     except (UnicodeDecodeError, json.JSONDecodeError) as e:
         raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e

@@ -159,7 +159,7 @@ def _extract_text_from_json(file_content: bytes) -> str:
 def _extract_text_from_yaml(file_content: bytes) -> str:
     """Extract the content from yaml file"""
     try:
-        yaml_data = yaml.safe_load_all(file_content.decode("utf-8"))
+        yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore"))
         return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
     except (UnicodeDecodeError, yaml.YAMLError) as e:
         raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e

@@ -217,7 +217,7 @@ def _extract_text_from_file(file: File):

 def _extract_text_from_csv(file_content: bytes) -> str:
     try:
-        csv_file = io.StringIO(file_content.decode("utf-8"))
+        csv_file = io.StringIO(file_content.decode("utf-8", "ignore"))
         csv_reader = csv.reader(csv_file)
         rows = list(csv_reader)
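Note: decode("utf-8", "ignore") drops undecodable bytes instead of raising, which makes the surrounding except UnicodeDecodeError branches a last resort rather than the common path:

# \xe9 on its own is not valid UTF-8, so "ignore" silently drops it.
assert b"caf\xe9 latte".decode("utf-8", "ignore") == "caf latte"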
api/core/workflow/nodes/llm/node.py

@@ -1,4 +1,5 @@
+import json
 import logging
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Optional, cast

@@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
 from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.file import FileType, file_manager
 from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities import (
     AudioPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
     TextPromptMessageContent,
     VideoPromptMessageContent,
 )
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageRole,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.variables import (

@@ -30,10 +36,13 @@ from core.variables import (
     FileSegment,
     NoneSegment,
     ObjectSegment,
+    SegmentGroup,
     StringSegment,
 )
 from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
+from core.workflow.entities.variable_entities import VariableSelector
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import InNodeEvent
 from core.workflow.nodes.base import BaseNode

@@ -62,14 +71,18 @@ from .exc import (
     InvalidVariableTypeError,
     LLMModeRequiredError,
     LLMNodeError,
+    MemoryRolePrefixRequiredError,
     ModelNotExistError,
     NoPromptFoundError,
+    NotSupportedPromptTypeError,
     VariableNotFoundError,
 )

 if TYPE_CHECKING:
     from core.file.models import File

+logger = logging.getLogger(__name__)
+

 class LLMNode(BaseNode[LLMNodeData]):
     _node_data_cls = LLMNodeData

@@ -131,9 +144,8 @@ class LLMNode(BaseNode[LLMNodeData]):
             query = None

         prompt_messages, stop = self._fetch_prompt_messages(
-            system_query=query,
-            inputs=inputs,
-            files=files,
+            user_query=query,
+            user_files=files,
             context=context,
             memory=memory,
             model_config=model_config,

@@ -181,6 +193,17 @@ class LLMNode(BaseNode[LLMNodeData]):
                 )
             )
             return
+        except Exception as e:
+            logger.exception(f"Node {self.node_id} failed to run: {e}")
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=str(e),
+                    inputs=node_inputs,
+                    process_data=process_data,
+                )
+            )
+            return

         outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}

@@ -203,7 +226,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         self,
         node_data_model: ModelConfig,
         model_instance: ModelInstance,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         stop: Optional[Sequence[str]] = None,
     ) -> Generator[NodeEvent, None, None]:
         db.session.close()

@@ -519,9 +542,8 @@ class LLMNode(BaseNode[LLMNodeData]):
     def _fetch_prompt_messages(
         self,
         *,
-        system_query: str | None = None,
-        inputs: dict[str, str] | None = None,
-        files: Sequence["File"],
+        user_query: str | None = None,
+        user_files: Sequence["File"],
         context: str | None = None,
         memory: TokenBufferMemory | None = None,
         model_config: ModelConfigWithCredentialsEntity,

@@ -529,60 +551,185 @@ class LLMNode(BaseNode[LLMNodeData]):
         memory_config: MemoryConfig | None = None,
         vision_enabled: bool = False,
         vision_detail: ImagePromptMessageContent.DETAIL,
-    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
-        inputs = inputs or {}
+    ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
+        prompt_messages = []

-        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_messages = prompt_transform.get_prompt(
-            prompt_template=prompt_template,
-            inputs=inputs,
-            query=system_query or "",
-            files=files,
-            context=context,
-            memory_config=memory_config,
-            memory=memory,
-            model_config=model_config,
-        )
-        stop = model_config.stop
+        if isinstance(prompt_template, list):
+            # For chat model
+            prompt_messages.extend(self._handle_list_messages(messages=prompt_template, context=context))
+
+            # Get memory messages for chat mode
+            memory_messages = self._handle_memory_chat_mode(
+                memory=memory,
+                memory_config=memory_config,
+                model_config=model_config,
+            )
+            # Extend prompt_messages with memory messages
+            prompt_messages.extend(memory_messages)
+
+            # Add current query to the prompt messages
+            if user_query:
+                prompt_messages.append(UserPromptMessage(content=[TextPromptMessageContent(data=user_query)]))
+
+        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
+            # For completion model
+            prompt_messages.extend(self._handle_completion_template(template=prompt_template, context=context))
+
+            # Get memory text for completion model
+            memory_text = self._handle_memory_completion_mode(
+                memory=memory,
+                memory_config=memory_config,
+                model_config=model_config,
+            )
+            # Insert histories into the prompt
+            prompt_content = prompt_messages[0].content
+            if "#histories#" in prompt_content:
+                prompt_content = prompt_content.replace("#histories#", memory_text)
+            else:
+                prompt_content = memory_text + "\n" + prompt_content
+            prompt_messages[0].content = prompt_content
+
+            # Add current query to the prompt message
+            if user_query:
+                prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
+                prompt_messages[0].content = prompt_content
+        else:
+            errmsg = f"Prompt type {type(prompt_template)} is not supported"
+            logger.warning(errmsg)
+            raise NotSupportedPromptTypeError(errmsg)
+
+        if vision_enabled and user_files:
+            file_prompts = []
+            for file in user_files:
+                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
+                file_prompts.append(file_prompt)
+            if (
+                len(prompt_messages) > 0
+                and isinstance(prompt_messages[-1], UserPromptMessage)
+                and isinstance(prompt_messages[-1].content, list)
+            ):
+                prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
+            else:
+                prompt_messages.append(UserPromptMessage(content=file_prompts))

         # Filter prompt messages
         filtered_prompt_messages = []
         for prompt_message in prompt_messages:
-            if prompt_message.is_empty():
-                continue
-
-            if not isinstance(prompt_message.content, str):
+            if isinstance(prompt_message.content, list):
                 prompt_message_content = []
-                for content_item in prompt_message.content or []:
-                    # Skip image if vision is disabled
-                    if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
+                for content_item in prompt_message.content:
+                    # Skip content if features are not defined
+                    if not model_config.model_schema.features:
+                        if content_item.type != PromptMessageContentType.TEXT:
+                            continue
+                        prompt_message_content.append(content_item)
                         continue

-                    if isinstance(content_item, ImagePromptMessageContent):
-                        # Override vision config if LLM node has vision config,
-                        # cuz vision detail is related to the configuration from FileUpload feature.
-                        content_item.detail = vision_detail
-                        prompt_message_content.append(content_item)
-                    elif isinstance(
-                        content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
+                    # Skip content if corresponding feature is not supported
+                    if (
+                        (
+                            content_item.type == PromptMessageContentType.IMAGE
+                            and (not vision_enabled or ModelFeature.VISION not in model_config.model_schema.features)
+                        )
+                        or (
+                            content_item.type == PromptMessageContentType.DOCUMENT
+                            and ModelFeature.DOCUMENT not in model_config.model_schema.features
+                        )
+                        or (
+                            content_item.type == PromptMessageContentType.VIDEO
+                            and ModelFeature.VIDEO not in model_config.model_schema.features
+                        )
+                        or (
+                            content_item.type == PromptMessageContentType.AUDIO
+                            and ModelFeature.AUDIO not in model_config.model_schema.features
+                        )
                     ):
-                        prompt_message_content.append(content_item)
-
-                if len(prompt_message_content) > 1:
-                    prompt_message.content = prompt_message_content
-                elif (
-                    len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
-                ):
+                        continue
+                    prompt_message_content.append(content_item)
+                if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
                     prompt_message.content = prompt_message_content[0].data
-
+                else:
+                    prompt_message.content = prompt_message_content
+            if prompt_message.is_empty():
+                continue
             filtered_prompt_messages.append(prompt_message)

-        if not filtered_prompt_messages:
+        if len(filtered_prompt_messages) == 0:
             raise NoPromptFoundError(
                 "No prompt found in the LLM configuration. "
                 "Please ensure a prompt is properly configured before proceeding."
             )

+        stop = model_config.stop
         return filtered_prompt_messages, stop

+    def _handle_memory_chat_mode(
+        self,
+        *,
+        memory: TokenBufferMemory | None,
+        memory_config: MemoryConfig | None,
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> Sequence[PromptMessage]:
+        memory_messages = []
+        # Get messages from memory for chat model
+        if memory and memory_config:
+            rest_tokens = self._calculate_rest_token([], model_config)
+            memory_messages = memory.get_history_prompt_messages(
+                max_token_limit=rest_tokens,
+                message_limit=memory_config.window.size if memory_config.window.enabled else None,
+            )
+        return memory_messages
+
+    def _handle_memory_completion_mode(
+        self,
+        *,
+        memory: TokenBufferMemory | None,
+        memory_config: MemoryConfig | None,
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> str:
+        memory_text = ""
+        # Get history text from memory for completion model
+        if memory and memory_config:
+            rest_tokens = self._calculate_rest_token([], model_config)
+            if not memory_config.role_prefix:
+                raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
+            memory_text = memory.get_history_prompt_text(
+                max_token_limit=rest_tokens,
+                message_limit=memory_config.window.size if memory_config.window.enabled else None,
+                human_prefix=memory_config.role_prefix.user,
+                ai_prefix=memory_config.role_prefix.assistant,
+            )
+        return memory_text
+
+    def _calculate_rest_token(
+        self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
+    ) -> int:
+        rest_tokens = 2000
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+        if model_context_tokens:
+            model_instance = ModelInstance(
+                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
+            )
+
+            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
+
+            max_tokens = 0
+            for parameter_rule in model_config.model_schema.parameter_rules:
+                if parameter_rule.name == "max_tokens" or (
+                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+                ):
+                    max_tokens = (
+                        model_config.parameters.get(parameter_rule.name)
+                        or model_config.parameters.get(str(parameter_rule.use_template))
+                        or 0
+                    )
+
+            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+            rest_tokens = max(rest_tokens, 0)
+
+        return rest_tokens

     @classmethod
     def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
         provider_model_bundle = model_instance.provider_model_bundle

@@ -715,3 +862,123 @@ class LLMNode(BaseNode[LLMNodeData]):
             }
         },
     }
+
+    def _handle_list_messages(
+        self, *, messages: Sequence[LLMNodeChatModelMessage], context: Optional[str]
+    ) -> Sequence[PromptMessage]:
+        prompt_messages = []
+        for message in messages:
+            if message.edition_type == "jinja2":
+                result_text = _render_jinja2_message(
+                    template=message.jinja2_text or "",
+                    jinjia2_variables=self.node_data.prompt_config.jinja2_variables,
+                    variable_pool=self.graph_runtime_state.variable_pool,
+                )
+                prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
+                prompt_messages.append(prompt_message)
+            else:
+                # Get segment group from basic message
+                segment_group = _render_basic_message(
+                    template=message.text,
+                    context=context,
+                    variable_pool=self.graph_runtime_state.variable_pool,
+                )
+
+                # Process segments for images
+                file_contents = []
+                for segment in segment_group.value:
+                    if isinstance(segment, ArrayFileSegment):
+                        for file in segment.value:
+                            if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
+                                file_content = file_manager.to_prompt_message_content(
+                                    file, image_detail_config=self.node_data.vision.configs.detail
+                                )
+                                file_contents.append(file_content)
+                    if isinstance(segment, FileSegment):
+                        file = segment.value
+                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
+                            file_content = file_manager.to_prompt_message_content(
+                                file, image_detail_config=self.node_data.vision.configs.detail
+                            )
+                            file_contents.append(file_content)
+
+                # Create message with text from all segments
+                plain_text = segment_group.text
+                if plain_text:
+                    prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
+                    prompt_messages.append(prompt_message)
+
+                if file_contents:
+                    # Create message with image contents
+                    prompt_message = UserPromptMessage(content=file_contents)
+                    prompt_messages.append(prompt_message)
+
+        return prompt_messages
+
+    def _handle_completion_template(
+        self, *, template: LLMNodeCompletionModelPromptTemplate, context: Optional[str]
+    ) -> Sequence[PromptMessage]:
+        prompt_messages = []
+        if template.edition_type == "jinja2":
+            result_text = _render_jinja2_message(
+                template=template.jinja2_text or "",
+                jinjia2_variables=self.node_data.prompt_config.jinja2_variables,
+                variable_pool=self.graph_runtime_state.variable_pool,
+            )
+        else:
+            result_text = _render_basic_message(
+                template=template.text,
+                context=context,
+                variable_pool=self.graph_runtime_state.variable_pool,
+            ).text
+        prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
+        prompt_messages.append(prompt_message)
+        return prompt_messages
+
+
+def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
+    match role:
+        case PromptMessageRole.USER:
+            return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
+        case PromptMessageRole.ASSISTANT:
+            return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
+        case PromptMessageRole.SYSTEM:
+            return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
+    raise NotImplementedError(f"Role {role} is not supported")
+
+
+def _render_jinja2_message(
+    *,
+    template: str,
+    jinjia2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+):
+    if not template:
+        return ""
+
+    jinjia2_inputs = {}
+    for jinja2_variable in jinjia2_variables:
+        variable = variable_pool.get(jinja2_variable.value_selector)
+        jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
+    code_execute_resp = CodeExecutor.execute_workflow_code_template(
+        language=CodeLanguage.JINJA2,
+        code=template,
+        inputs=jinjia2_inputs,
+    )
+    result_text = code_execute_resp["result"]
+    return result_text
+
+
+def _render_basic_message(
+    *,
+    template: str,
+    context: str | None,
+    variable_pool: VariablePool,
+) -> SegmentGroup:
+    if not template:
+        return SegmentGroup(value=[])
+
+    if context:
+        template = template.replace("{#context#}", context)
+
+    return variable_pool.convert_template(template)
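Note: the core of the new filtering is a feature gate per content type. A condensed, standalone sketch of that rule (it omits the vision_enabled check for images and the text-only fallback used when a model schema declares no features at all):

from core.model_runtime.entities import PromptMessageContentType
from core.model_runtime.entities.model_entities import ModelFeature

REQUIRED_FEATURE = {
    PromptMessageContentType.IMAGE: ModelFeature.VISION,
    PromptMessageContentType.DOCUMENT: ModelFeature.DOCUMENT,
    PromptMessageContentType.VIDEO: ModelFeature.VIDEO,
    PromptMessageContentType.AUDIO: ModelFeature.AUDIO,
}

def keep(content_type, features) -> bool:
    # Text always passes; other content requires its matching model feature.
    if content_type == PromptMessageContentType.TEXT:
        return True
    required = REQUIRED_FEATURE.get(content_type)
    return required is not None and required in features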
api/docker/entrypoint.sh

@@ -21,7 +21,7 @@ if [[ "${MODE}" == "worker" ]]; then
 fi

 exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel ${LOG_LEVEL} \
-  -Q ${CELERY_QUEUES:-dataset,generation,mail,ops_trace,app_deletion}
+  -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion}

 elif [[ "${MODE}" == "beat" ]]; then
   exec celery -A app.celery beat --loglevel ${LOG_LEVEL}
2  api/poetry.lock  generated

@@ -11020,4 +11020,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "75a954fab629ce88fced727706a253173983bb08647b79fa342d0fc5b3b4743a"
+content-hash = "0ab603323ea1d83690d4ee61e6d199a2bca6f3e2cc4b454a4ebf99aa6f6907bd"
api/pyproject.toml

@@ -265,11 +265,11 @@ weaviate-client = "~3.21.0"
 optional = true
 [tool.poetry.group.dev.dependencies]
 coverage = "~7.2.4"
+faker = "~32.1.0"
 pytest = "~8.3.2"
 pytest-benchmark = "~4.0.0"
 pytest-env = "~1.1.3"
 pytest-mock = "~3.14.0"
-faker = "^32.1.0"

 ############################################################
 # [ Lint ] dependency group
api/services/dataset_service.py

@@ -1458,6 +1458,7 @@ class SegmentService:
         pre_segment_data_list = []
         segment_data_list = []
         keywords_list = []
+        position = max_position + 1 if max_position else 1
         for segment_item in segments:
             content = segment_item["content"]
             doc_id = str(uuid.uuid4())

@@ -1475,7 +1476,7 @@ class SegmentService:
                 document_id=document.id,
                 index_node_id=doc_id,
                 index_node_hash=segment_hash,
-                position=max_position + 1 if max_position else 1,
+                position=position,
                 content=content,
                 word_count=len(content),
                 tokens=tokens,

@@ -1490,6 +1491,7 @@ class SegmentService:
                 increment_word_count += segment_document.word_count
             db.session.add(segment_document)
             segment_data_list.append(segment_document)
+            position += 1

             pre_segment_data_list.append(segment_document)
             if "keywords" in segment_item:
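Note: before this fix every segment in a batch was written with the same max_position + 1; the local counter gives each row a distinct position. A minimal illustration:

max_position = 5
position = max_position + 1 if max_position else 1
positions = []
for _ in range(3):  # three segments in one batch
    positions.append(position)
    position += 1
assert positions == [6, 7, 8]  # previously all three would have been 6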
api/tasks/document_indexing_task.py

@@ -25,7 +25,9 @@ def document_indexing_task(dataset_id: str, document_ids: list):
     start_at = time.perf_counter()

     dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
-
+    if not dataset:
+        logging.info(click.style("Dataset is not found: {}".format(dataset_id), fg="yellow"))
+        return
     # check document limit
     features = FeatureService.get_features(dataset.tenant_id)
     try:
@@ -1,125 +1,480 @@
+from collections.abc import Sequence
+from typing import Optional
+
 import pytest

-from core.app.entities.app_invoke_entities import InvokeFrom
+from configs import dify_config
+from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
+from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
 from core.file import File, FileTransferMethod, FileType
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageRole,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
+from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
+from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
 from core.workflow.nodes.answer import AnswerStreamGenerateRoute
 from core.workflow.nodes.end import EndStreamParam
-from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
+from core.workflow.nodes.llm.entities import (
+    ContextConfig,
+    LLMNodeChatModelMessage,
+    LLMNodeData,
+    ModelConfig,
+    VisionConfig,
+    VisionConfigOptions,
+)
 from core.workflow.nodes.llm.node import LLMNode
 from models.enums import UserFrom
+from models.provider import ProviderType
 from models.workflow import WorkflowType
+from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario


-class TestLLMNode:
-    @pytest.fixture
-    def llm_node(self):
-        data = LLMNodeData(
-            title="Test LLM",
-            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
-            prompt_template=[],
-            memory=None,
-            context=ContextConfig(enabled=False),
-            vision=VisionConfig(
-                enabled=True,
-                configs=VisionConfigOptions(
-                    variable_selector=["sys", "files"],
-                    detail=ImagePromptMessageContent.DETAIL.HIGH,
-                ),
-            ),
-        )
-        variable_pool = VariablePool(
-            system_variables={},
-            user_inputs={},
-        )
-        node = LLMNode(
-            id="1",
-            config={
-                "id": "1",
-                "data": data.model_dump(),
-            },
-            graph_init_params=GraphInitParams(
-                tenant_id="1",
-                app_id="1",
-                workflow_type=WorkflowType.WORKFLOW,
-                workflow_id="1",
-                graph_config={},
-                user_id="1",
-                user_from=UserFrom.ACCOUNT,
-                invoke_from=InvokeFrom.SERVICE_API,
-                call_depth=0,
-            ),
-            graph=Graph(
-                root_node_id="1",
-                answer_stream_generate_routes=AnswerStreamGenerateRoute(
-                    answer_dependencies={},
-                    answer_generate_route={},
-                ),
-                end_stream_param=EndStreamParam(
-                    end_dependencies={},
-                    end_stream_variable_selector_mapping={},
-                ),
-            ),
-            graph_runtime_state=GraphRuntimeState(
-                variable_pool=variable_pool,
-                start_at=0,
-            ),
-        )
-        return node
-
-    def test_fetch_files_with_file_segment(self, llm_node):
-        file = File(
-            id="1",
-            tenant_id="test",
-            type=FileType.IMAGE,
-            filename="test.jpg",
-            transfer_method=FileTransferMethod.LOCAL_FILE,
-            related_id="1",
-        )
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
-
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == [file]
-
-    def test_fetch_files_with_array_file_segment(self, llm_node):
-        files = [
-            File(
-                id="1",
-                tenant_id="test",
-                type=FileType.IMAGE,
-                filename="test.jpg",
-                transfer_method=FileTransferMethod.LOCAL_FILE,
-                related_id="1",
-            ),
-            File(
-                id="2",
-                tenant_id="test",
-                type=FileType.IMAGE,
-                filename="test2.jpg",
-                transfer_method=FileTransferMethod.LOCAL_FILE,
-                related_id="2",
-            ),
-        ]
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
-
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == files
-
-    def test_fetch_files_with_none_segment(self, llm_node):
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
-
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
-
-    def test_fetch_files_with_array_any_segment(self, llm_node):
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
-
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
-
-    def test_fetch_files_with_non_existent_variable(self, llm_node):
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
+class MockTokenBufferMemory:
+    def __init__(self, history_messages=None):
+        self.history_messages = history_messages or []
+
+    def get_history_prompt_messages(
+        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
+    ) -> Sequence[PromptMessage]:
+        if message_limit is not None:
+            return self.history_messages[-message_limit * 2 :]
+        return self.history_messages
+
+
+@pytest.fixture
+def llm_node():
+    data = LLMNodeData(
+        title="Test LLM",
+        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
+        prompt_template=[],
+        memory=None,
+        context=ContextConfig(enabled=False),
+        vision=VisionConfig(
+            enabled=True,
+            configs=VisionConfigOptions(
+                variable_selector=["sys", "files"],
+                detail=ImagePromptMessageContent.DETAIL.HIGH,
+            ),
+        ),
+    )
+    variable_pool = VariablePool(
+        system_variables={},
+        user_inputs={},
+    )
+    node = LLMNode(
+        id="1",
+        config={
+            "id": "1",
+            "data": data.model_dump(),
+        },
+        graph_init_params=GraphInitParams(
+            tenant_id="1",
+            app_id="1",
+            workflow_type=WorkflowType.WORKFLOW,
+            workflow_id="1",
+            graph_config={},
+            user_id="1",
+            user_from=UserFrom.ACCOUNT,
+            invoke_from=InvokeFrom.SERVICE_API,
+            call_depth=0,
+        ),
+        graph=Graph(
+            root_node_id="1",
+            answer_stream_generate_routes=AnswerStreamGenerateRoute(
+                answer_dependencies={},
+                answer_generate_route={},
+            ),
+            end_stream_param=EndStreamParam(
+                end_dependencies={},
+                end_stream_variable_selector_mapping={},
+            ),
+        ),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=variable_pool,
+            start_at=0,
+        ),
+    )
+    return node
+
+
+@pytest.fixture
+def model_config():
+    # Create actual provider and model type instances
+    model_provider_factory = ModelProviderFactory()
+    provider_instance = model_provider_factory.get_provider_instance("openai")
+    model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+
+    # Create a ProviderModelBundle
+    provider_model_bundle = ProviderModelBundle(
+        configuration=ProviderConfiguration(
+            tenant_id="1",
+            provider=provider_instance.get_provider_schema(),
+            preferred_provider_type=ProviderType.CUSTOM,
+            using_provider_type=ProviderType.CUSTOM,
+            system_configuration=SystemConfiguration(enabled=False),
+            custom_configuration=CustomConfiguration(provider=None),
+            model_settings=[],
+        ),
+        provider_instance=provider_instance,
+        model_type_instance=model_type_instance,
+    )
+
+    # Create and return a ModelConfigWithCredentialsEntity
+    return ModelConfigWithCredentialsEntity(
+        provider="openai",
+        model="gpt-3.5-turbo",
+        model_schema=AIModelEntity(
+            model="gpt-3.5-turbo",
+            label=I18nObject(en_US="GPT-3.5 Turbo"),
+            model_type=ModelType.LLM,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={},
+        ),
+        mode="chat",
+        credentials={},
+        parameters={},
+        provider_model_bundle=provider_model_bundle,
+    )
+
+
+def test_fetch_files_with_file_segment(llm_node):
+    file = File(
+        id="1",
+        tenant_id="test",
+        type=FileType.IMAGE,
+        filename="test.jpg",
+        transfer_method=FileTransferMethod.LOCAL_FILE,
+        related_id="1",
+    )
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == [file]
+
+
+def test_fetch_files_with_array_file_segment(llm_node):
+    files = [
+        File(
+            id="1",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test1.jpg",
+            transfer_method=FileTransferMethod.LOCAL_FILE,
+            related_id="1",
+        ),
+        File(
+            id="2",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test2.jpg",
+            transfer_method=FileTransferMethod.LOCAL_FILE,
+            related_id="2",
+        ),
+    ]
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == files
+
+
+def test_fetch_files_with_none_segment(llm_node):
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_files_with_array_any_segment(llm_node):
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_files_with_non_existent_variable(llm_node):
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
+    prompt_template = []
+    llm_node.node_data.prompt_template = prompt_template
+
+    fake_vision_detail = faker.random_element(
+        [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
+    )
+    fake_remote_url = faker.url()
+    files = [
+        File(
+            id="1",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test1.jpg",
+            transfer_method=FileTransferMethod.REMOTE_URL,
+            remote_url=fake_remote_url,
+        )
+    ]
+
+    fake_query = faker.sentence()
+
+    prompt_messages, _ = llm_node._fetch_prompt_messages(
+        user_query=fake_query,
+        user_files=files,
+        context=None,
+        memory=None,
+        model_config=model_config,
+        prompt_template=prompt_template,
+        memory_config=None,
+        vision_enabled=False,
+        vision_detail=fake_vision_detail,
+    )
+
+    assert prompt_messages == [UserPromptMessage(content=fake_query)]
+
+
+def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
+    # Setup dify config
+    dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
+    dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
+
+    # Generate fake values for prompt template
+    fake_assistant_prompt = faker.sentence()
+    fake_query = faker.sentence()
+    fake_context = faker.sentence()
+    fake_window_size = faker.random_int(min=1, max=3)
+    fake_vision_detail = faker.random_element(
+        [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
+    )
+    fake_remote_url = faker.url()
+
+    # Setup mock memory with history messages
+    mock_history = [
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+    ]
+
+    # Setup memory configuration
+    memory_config = MemoryConfig(
+        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
+        window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
+        query_prompt_template=None,
+    )
+
+    memory = MockTokenBufferMemory(history_messages=mock_history)
+
+    # Test scenarios covering different file input combinations
+    test_scenarios = [
+        LLMNodeTestScenario(
+            description="No files",
+            user_query=fake_query,
+            user_files=[],
+            features=[],
+            vision_enabled=False,
+            vision_detail=None,
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text=fake_context,
+                    role=PromptMessageRole.SYSTEM,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text="{#context#}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text=fake_assistant_prompt,
+                    role=PromptMessageRole.ASSISTANT,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=[
+                SystemPromptMessage(content=fake_context),
+                UserPromptMessage(content=fake_context),
+                AssistantPromptMessage(content=fake_assistant_prompt),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [
+                UserPromptMessage(content=fake_query),
+            ],
+        ),
+        LLMNodeTestScenario(
+            description="User files",
+            user_query=fake_query,
+            user_files=[
+                File(
+                    tenant_id="test",
+                    type=FileType.IMAGE,
+                    filename="test1.jpg",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                )
+            ],
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
+            features=[ModelFeature.VISION],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text=fake_context,
+                    role=PromptMessageRole.SYSTEM,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text="{#context#}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text=fake_assistant_prompt,
+                    role=PromptMessageRole.ASSISTANT,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=[
+                SystemPromptMessage(content=fake_context),
+                UserPromptMessage(content=fake_context),
+                AssistantPromptMessage(content=fake_assistant_prompt),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [
+                UserPromptMessage(
+                    content=[
+                        TextPromptMessageContent(data=fake_query),
+                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                    ]
+                ),
+            ],
+        ),
+        LLMNodeTestScenario(
+            description="Prompt template with variable selector of File",
+            user_query=fake_query,
+            user_files=[],
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
+            features=[ModelFeature.VISION],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text="{{#input.image#}}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=[
+                UserPromptMessage(
+                    content=[
+                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                    ]
+                ),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [UserPromptMessage(content=fake_query)],
+            file_variables={
+                "input.image": File(
+                    tenant_id="test",
+                    type=FileType.IMAGE,
+                    filename="test1.jpg",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                )
+            },
+        ),
+        LLMNodeTestScenario(
+            description="Prompt template with variable selector of File without vision feature",
+            user_query=fake_query,
+            user_files=[],
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
+            features=[],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text="{{#input.image#}}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
+            file_variables={
+                "input.image": File(
+                    tenant_id="test",
+                    type=FileType.IMAGE,
+                    filename="test1.jpg",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                )
+            },
+        ),
+        LLMNodeTestScenario(
+            description="Prompt template with variable selector of File with video file and vision feature",
+            user_query=fake_query,
+            user_files=[],
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
+            features=[ModelFeature.VISION],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text="{{#input.image#}}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
+            file_variables={
+                "input.image": File(
+                    tenant_id="test",
+                    type=FileType.VIDEO,
+                    filename="test1.mp4",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                    extension="mp4",
+                )
+            },
+        ),
+    ]
+
+    for scenario in test_scenarios:
+        model_config.model_schema.features = scenario.features
+
+        for k, v in scenario.file_variables.items():
+            selector = k.split(".")
+            llm_node.graph_runtime_state.variable_pool.add(selector, v)
+
+        # Call the method under test
+        prompt_messages, _ = llm_node._fetch_prompt_messages(
+            user_query=scenario.user_query,
+            user_files=scenario.user_files,
+            context=fake_context,
+            memory=memory,
+            model_config=model_config,
+            prompt_template=scenario.prompt_template,
+            memory_config=memory_config,
+            vision_enabled=scenario.vision_enabled,
+            vision_detail=scenario.vision_detail,
+        )
+
+        # Verify the result
+        assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
+        assert (
+            prompt_messages == scenario.expected_messages
+        ), f"Message content mismatch in scenario: {scenario.description}"

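The `MockTokenBufferMemory` above slices with `[-message_limit * 2 :]` because the memory window is counted in user/assistant exchange pairs, so each unit of `message_limit` spans two messages. A quick sanity check of that arithmetic (plain strings standing in for prompt messages):

    history = ["u1", "a1", "u2", "a2", "u3", "a3"]
    window_size = 2
    assert history[-window_size * 2 :] == ["u2", "a2", "u3", "a3"]
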
@@ -0,0 +1,25 @@
+from collections.abc import Mapping, Sequence
+
+from pydantic import BaseModel, Field
+
+from core.file import File
+from core.model_runtime.entities.message_entities import PromptMessage
+from core.model_runtime.entities.model_entities import ModelFeature
+from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
+
+
+class LLMNodeTestScenario(BaseModel):
+    """Test scenario for LLM node testing."""
+
+    description: str = Field(..., description="Description of the test scenario")
+    user_query: str = Field(..., description="User query input")
+    user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
+    vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
+    vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
+    features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
+    window_size: int = Field(..., description="Window size for memory")
+    prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
+    file_variables: Mapping[str, File | Sequence[File]] = Field(
+        default_factory=dict, description="List of file variables"
+    )
+    expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")

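For reference, only the fields without defaults are required, so a minimal scenario instance (illustrative values, not from the test suite) could look like:

    smoke = LLMNodeTestScenario(
        description="smoke test",
        user_query="hello",
        window_size=1,
        prompt_template=[],
        expected_messages=[],
    )
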
@@ -140,6 +140,17 @@ def test_extract_text_from_plain_text():
     assert text == "Hello, world!"


+def test_extract_text_from_plain_text_non_utf8():
+    import tempfile
+
+    non_utf8_content = b"Hello, world\xa9."  # \xA9 represents © in Latin-1
+    with tempfile.NamedTemporaryFile(delete=True) as temp_file:
+        temp_file.write(non_utf8_content)
+        temp_file.seek(0)
+        text = _extract_text_from_plain_text(temp_file.read())
+    assert text == "Hello, world."
+
+
 @patch("pypdfium2.PdfDocument")
 def test_extract_text_from_pdf(mock_pdf_document):
     mock_page = Mock()

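The expected output drops the Latin-1 © byte entirely, which is consistent with a lenient UTF-8 decode. A sketch of the kind of fallback the extractor presumably applies (hypothetical helper, not Dify's actual implementation):

    def _extract_text_lenient(data: bytes) -> str:
        # invalid byte sequences are dropped instead of raising UnicodeDecodeError
        return data.decode("utf-8", errors="ignore")

    assert _extract_text_lenient(b"Hello, world\xa9.") == "Hello, world."
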
@@ -689,6 +689,9 @@ TEMPLATE_TRANSFORM_MAX_LENGTH=80000
 CODE_MAX_STRING_ARRAY_LENGTH=30
 CODE_MAX_OBJECT_ARRAY_LENGTH=30
 CODE_MAX_NUMBER_ARRAY_LENGTH=1000
+CODE_EXECUTION_CONNECT_TIMEOUT=10
+CODE_EXECUTION_READ_TIMEOUT=60
+CODE_EXECUTION_WRITE_TIMEOUT=10

 # Workflow runtime configuration
 WORKFLOW_MAX_EXECUTION_STEPS=500

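These three variables presumably govern the HTTP calls made to the code-execution sandbox (see CODE_EXECUTION_ENDPOINT in the compose file below). A hypothetical sketch of how they could map onto an httpx timeout; the env-var names come from the block above, but the client wiring itself is assumed, not Dify's actual code:

    import os

    import httpx

    timeout = httpx.Timeout(
        connect=float(os.getenv("CODE_EXECUTION_CONNECT_TIMEOUT", "10")),
        read=float(os.getenv("CODE_EXECUTION_READ_TIMEOUT", "60")),
        write=float(os.getenv("CODE_EXECUTION_WRITE_TIMEOUT", "10")),
        pool=10.0,  # pool timeout is not configured above; value assumed
    )
    client = httpx.Client(timeout=timeout)
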
@@ -244,6 +244,9 @@ x-shared-env: &shared-api-worker-env
   RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5}
   CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194}
   CODE_EXECUTION_API_KEY: ${SANDBOX_API_KEY:-dify-sandbox}
+  CODE_EXECUTION_CONNECT_TIMEOUT: ${CODE_EXECUTION_CONNECT_TIMEOUT:-10}
+  CODE_EXECUTION_READ_TIMEOUT: ${CODE_EXECUTION_READ_TIMEOUT:-60}
+  CODE_EXECUTION_WRITE_TIMEOUT: ${CODE_EXECUTION_WRITE_TIMEOUT:-10}
   CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807}
   CODE_MIN_NUMBER: ${CODE_MIN_NUMBER:--9223372036854775808}
   CODE_MAX_DEPTH: ${CODE_MAX_DEPTH:-5}

@@ -1,5 +1,5 @@
 # base image
-FROM node:20.11-alpine3.19 AS base
+FROM node:20-alpine3.20 AS base
 LABEL maintainer="takatost@gmail.com"

 # if you located in China, you can use aliyun mirror to speed up

@@ -44,12 +44,6 @@ export const fileUpload: FileUpload = ({
 }

 export const getFileExtension = (fileName: string, fileMimetype: string, isRemote?: boolean) => {
-  if (fileMimetype)
-    return mime.getExtension(fileMimetype) || ''
-
-  if (isRemote)
-    return ''
-
   if (fileName) {
     const fileNamePair = fileName.split('.')
     const fileNamePairLength = fileNamePair.length
@@ -58,6 +52,12 @@ export const getFileExtension = (fileName: string, fileMimetype: string, isRemot
       return fileNamePair[fileNamePairLength - 1]
   }

+  if (fileMimetype)
+    return mime.getExtension(fileMimetype) || ''
+
+  if (isRemote)
+    return ''
+
   return ''
 }

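The reorder above makes the file name's own suffix take precedence, demoting the MIME-type lookup to a fallback. A hypothetical Python analog of the new ordering (the project's real helper is the TypeScript shown above):

    import mimetypes

    def get_file_extension(file_name: str, file_mimetype: str, is_remote: bool = False) -> str:
        # 1) trust the explicit extension in the file name first
        if "." in file_name:
            return file_name.rsplit(".", 1)[-1]
        # 2) fall back to the MIME type
        if file_mimetype:
            guess = mimetypes.guess_extension(file_mimetype)
            if guess:
                return guess.lstrip(".")
        # 3) remote files with neither yield no extension
        return ""

    assert get_file_extension("report.pdf", "image/png") == "pdf"  # name wins now
    assert get_file_extension("blob", "image/png") == "png"        # MIME fallback
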
@@ -53,6 +53,6 @@ export const getInputVars = (text: string): ValueSelector[] => {
 export const FILE_EXTS: Record<string, string[]> = {
   [SupportUploadFileTypes.image]: ['JPG', 'JPEG', 'PNG', 'GIF', 'WEBP', 'SVG'],
   [SupportUploadFileTypes.document]: ['TXT', 'MD', 'MARKDOWN', 'PDF', 'HTML', 'XLSX', 'XLS', 'DOCX', 'CSV', 'EML', 'MSG', 'PPTX', 'PPT', 'XML', 'EPUB'],
-  [SupportUploadFileTypes.audio]: ['MP3', 'M4A', 'WAV', 'WEBM', 'AMR'],
+  [SupportUploadFileTypes.audio]: ['MP3', 'M4A', 'WAV', 'WEBM', 'AMR', 'MPGA'],
   [SupportUploadFileTypes.video]: ['MP4', 'MOV', 'MPEG', 'MPGA'],
 }