mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
fix the issue of the refine_switches at param being invalid in the Novita.AI tool (#7485)
This commit is contained in:
parent
66dfb5c89a
commit
0c99a3d0c5
|
@ -0,0 +1,73 @@
|
||||||
|
from novita_client import (
|
||||||
|
Txt2ImgV3Embedding,
|
||||||
|
Txt2ImgV3HiresFix,
|
||||||
|
Txt2ImgV3LoRA,
|
||||||
|
Txt2ImgV3Refiner,
|
||||||
|
V3TaskImage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NovitaAiToolBase:
|
||||||
|
def _extract_loras(self, loras_str: str):
|
||||||
|
if not loras_str:
|
||||||
|
return []
|
||||||
|
|
||||||
|
loras_ori_list = lora_str.strip().split(';')
|
||||||
|
result_list = []
|
||||||
|
for lora_str in loras_ori_list:
|
||||||
|
lora_info = lora_str.strip().split(',')
|
||||||
|
lora = Txt2ImgV3LoRA(
|
||||||
|
model_name=lora_info[0].strip(),
|
||||||
|
strength=float(lora_info[1]),
|
||||||
|
)
|
||||||
|
result_list.append(lora)
|
||||||
|
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
def _extract_embeddings(self, embeddings_str: str):
|
||||||
|
if not embeddings_str:
|
||||||
|
return []
|
||||||
|
|
||||||
|
embeddings_ori_list = embeddings_str.strip().split(';')
|
||||||
|
result_list = []
|
||||||
|
for embedding_str in embeddings_ori_list:
|
||||||
|
embedding = Txt2ImgV3Embedding(
|
||||||
|
model_name=embedding_str.strip()
|
||||||
|
)
|
||||||
|
result_list.append(embedding)
|
||||||
|
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
def _extract_hires_fix(self, hires_fix_str: str):
|
||||||
|
hires_fix_info = hires_fix_str.strip().split(',')
|
||||||
|
if 'upscaler' in hires_fix_info:
|
||||||
|
hires_fix = Txt2ImgV3HiresFix(
|
||||||
|
target_width=int(hires_fix_info[0]),
|
||||||
|
target_height=int(hires_fix_info[1]),
|
||||||
|
strength=float(hires_fix_info[2]),
|
||||||
|
upscaler=hires_fix_info[3].strip()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hires_fix = Txt2ImgV3HiresFix(
|
||||||
|
target_width=int(hires_fix_info[0]),
|
||||||
|
target_height=int(hires_fix_info[1]),
|
||||||
|
strength=float(hires_fix_info[2])
|
||||||
|
)
|
||||||
|
|
||||||
|
return hires_fix
|
||||||
|
|
||||||
|
def _extract_refiner(self, switch_at: str):
|
||||||
|
refiner = Txt2ImgV3Refiner(
|
||||||
|
switch_at=float(switch_at)
|
||||||
|
)
|
||||||
|
return refiner
|
||||||
|
|
||||||
|
def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
|
||||||
|
"""
|
||||||
|
is hit nsfw
|
||||||
|
"""
|
||||||
|
if image.nsfw_detection_result is None:
|
||||||
|
return False
|
||||||
|
if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
|
||||||
|
return True
|
||||||
|
return False
|
|
@ -4,19 +4,15 @@ from typing import Any, Union
|
||||||
|
|
||||||
from novita_client import (
|
from novita_client import (
|
||||||
NovitaClient,
|
NovitaClient,
|
||||||
Txt2ImgV3Embedding,
|
|
||||||
Txt2ImgV3HiresFix,
|
|
||||||
Txt2ImgV3LoRA,
|
|
||||||
Txt2ImgV3Refiner,
|
|
||||||
V3TaskImage,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
from core.tools.errors import ToolProviderCredentialValidationError
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
|
from core.tools.provider.builtin.novitaai._novita_tool_base import NovitaAiToolBase
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
|
|
||||||
class NovitaAiTxt2ImgTool(BuiltinTool):
|
class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
|
||||||
def _invoke(self,
|
def _invoke(self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
tool_parameters: dict[str, Any],
|
tool_parameters: dict[str, Any],
|
||||||
|
@ -73,65 +69,19 @@ class NovitaAiTxt2ImgTool(BuiltinTool):
|
||||||
|
|
||||||
# process loras
|
# process loras
|
||||||
if 'loras' in res_parameters:
|
if 'loras' in res_parameters:
|
||||||
loras_ori_list = res_parameters.get('loras').strip().split(';')
|
res_parameters['loras'] = self._extract_loras(res_parameters.get('loras'))
|
||||||
locals_list = []
|
|
||||||
for lora_str in loras_ori_list:
|
|
||||||
lora_info = lora_str.strip().split(',')
|
|
||||||
lora = Txt2ImgV3LoRA(
|
|
||||||
model_name=lora_info[0].strip(),
|
|
||||||
strength=float(lora_info[1]),
|
|
||||||
)
|
|
||||||
locals_list.append(lora)
|
|
||||||
|
|
||||||
res_parameters['loras'] = locals_list
|
|
||||||
|
|
||||||
# process embeddings
|
# process embeddings
|
||||||
if 'embeddings' in res_parameters:
|
if 'embeddings' in res_parameters:
|
||||||
embeddings_ori_list = res_parameters.get('embeddings').strip().split(';')
|
res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings'))
|
||||||
locals_list = []
|
|
||||||
for embedding_str in embeddings_ori_list:
|
|
||||||
embedding = Txt2ImgV3Embedding(
|
|
||||||
model_name=embedding_str.strip()
|
|
||||||
)
|
|
||||||
locals_list.append(embedding)
|
|
||||||
|
|
||||||
res_parameters['embeddings'] = locals_list
|
|
||||||
|
|
||||||
# process hires_fix
|
# process hires_fix
|
||||||
if 'hires_fix' in res_parameters:
|
if 'hires_fix' in res_parameters:
|
||||||
hires_fix_ori = res_parameters.get('hires_fix')
|
res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix'))
|
||||||
hires_fix_info = hires_fix_ori.strip().split(',')
|
|
||||||
if 'upscaler' in hires_fix_info:
|
|
||||||
hires_fix = Txt2ImgV3HiresFix(
|
|
||||||
target_width=int(hires_fix_info[0]),
|
|
||||||
target_height=int(hires_fix_info[1]),
|
|
||||||
strength=float(hires_fix_info[2]),
|
|
||||||
upscaler=hires_fix_info[3].strip()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hires_fix = Txt2ImgV3HiresFix(
|
|
||||||
target_width=int(hires_fix_info[0]),
|
|
||||||
target_height=int(hires_fix_info[1]),
|
|
||||||
strength=float(hires_fix_info[2])
|
|
||||||
)
|
|
||||||
|
|
||||||
res_parameters['hires_fix'] = hires_fix
|
# process refiner
|
||||||
|
if 'refiner_switch_at' in res_parameters:
|
||||||
if 'refiner_switch_at' in res_parameters:
|
res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at'))
|
||||||
refiner = Txt2ImgV3Refiner(
|
del res_parameters['refiner_switch_at']
|
||||||
switch_at=float(res_parameters.get('refiner_switch_at'))
|
|
||||||
)
|
|
||||||
del res_parameters['refiner_switch_at']
|
|
||||||
res_parameters['refiner'] = refiner
|
|
||||||
|
|
||||||
return res_parameters
|
return res_parameters
|
||||||
|
|
||||||
def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
|
|
||||||
"""
|
|
||||||
is hit nsfw
|
|
||||||
"""
|
|
||||||
if image.nsfw_detection_result is None:
|
|
||||||
return False
|
|
||||||
if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
1556
api/poetry.lock
generated
1556
api/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
|
@ -153,7 +153,7 @@ langfuse = "^2.36.1"
|
||||||
langsmith = "^0.1.77"
|
langsmith = "^0.1.77"
|
||||||
mailchimp-transactional = "~1.0.50"
|
mailchimp-transactional = "~1.0.50"
|
||||||
markdown = "~3.5.1"
|
markdown = "~3.5.1"
|
||||||
novita-client = "^0.5.6"
|
novita-client = "^0.5.7"
|
||||||
numpy = "~1.26.4"
|
numpy = "~1.26.4"
|
||||||
openai = "~1.29.0"
|
openai = "~1.29.0"
|
||||||
openpyxl = "~3.1.5"
|
openpyxl = "~3.1.5"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user