fix the issue of the refine_switches at param being invalid in the Novita.AI tool (#7485)

This commit is contained in:
Xiao Ley 2024-08-21 15:09:05 +08:00 committed by GitHub
parent 66dfb5c89a
commit 0c99a3d0c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 890 additions and 809 deletions

View File

@ -0,0 +1,73 @@
from novita_client import (
Txt2ImgV3Embedding,
Txt2ImgV3HiresFix,
Txt2ImgV3LoRA,
Txt2ImgV3Refiner,
V3TaskImage,
)
class NovitaAiToolBase:
def _extract_loras(self, loras_str: str):
if not loras_str:
return []
loras_ori_list = lora_str.strip().split(';')
result_list = []
for lora_str in loras_ori_list:
lora_info = lora_str.strip().split(',')
lora = Txt2ImgV3LoRA(
model_name=lora_info[0].strip(),
strength=float(lora_info[1]),
)
result_list.append(lora)
return result_list
def _extract_embeddings(self, embeddings_str: str):
if not embeddings_str:
return []
embeddings_ori_list = embeddings_str.strip().split(';')
result_list = []
for embedding_str in embeddings_ori_list:
embedding = Txt2ImgV3Embedding(
model_name=embedding_str.strip()
)
result_list.append(embedding)
return result_list
def _extract_hires_fix(self, hires_fix_str: str):
hires_fix_info = hires_fix_str.strip().split(',')
if 'upscaler' in hires_fix_info:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2]),
upscaler=hires_fix_info[3].strip()
)
else:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2])
)
return hires_fix
def _extract_refiner(self, switch_at: str):
refiner = Txt2ImgV3Refiner(
switch_at=float(switch_at)
)
return refiner
def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
"""
is hit nsfw
"""
if image.nsfw_detection_result is None:
return False
if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
return True
return False

View File

@ -4,19 +4,15 @@ from typing import Any, Union
from novita_client import (
NovitaClient,
Txt2ImgV3Embedding,
Txt2ImgV3HiresFix,
Txt2ImgV3LoRA,
Txt2ImgV3Refiner,
V3TaskImage,
)
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.novitaai._novita_tool_base import NovitaAiToolBase
from core.tools.tool.builtin_tool import BuiltinTool
class NovitaAiTxt2ImgTool(BuiltinTool):
class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
@ -73,65 +69,19 @@ class NovitaAiTxt2ImgTool(BuiltinTool):
# process loras
if 'loras' in res_parameters:
loras_ori_list = res_parameters.get('loras').strip().split(';')
locals_list = []
for lora_str in loras_ori_list:
lora_info = lora_str.strip().split(',')
lora = Txt2ImgV3LoRA(
model_name=lora_info[0].strip(),
strength=float(lora_info[1]),
)
locals_list.append(lora)
res_parameters['loras'] = locals_list
res_parameters['loras'] = self._extract_loras(res_parameters.get('loras'))
# process embeddings
if 'embeddings' in res_parameters:
embeddings_ori_list = res_parameters.get('embeddings').strip().split(';')
locals_list = []
for embedding_str in embeddings_ori_list:
embedding = Txt2ImgV3Embedding(
model_name=embedding_str.strip()
)
locals_list.append(embedding)
res_parameters['embeddings'] = locals_list
res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings'))
# process hires_fix
if 'hires_fix' in res_parameters:
hires_fix_ori = res_parameters.get('hires_fix')
hires_fix_info = hires_fix_ori.strip().split(',')
if 'upscaler' in hires_fix_info:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2]),
upscaler=hires_fix_info[3].strip()
)
else:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2])
)
res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix'))
res_parameters['hires_fix'] = hires_fix
if 'refiner_switch_at' in res_parameters:
refiner = Txt2ImgV3Refiner(
switch_at=float(res_parameters.get('refiner_switch_at'))
)
del res_parameters['refiner_switch_at']
res_parameters['refiner'] = refiner
# process refiner
if 'refiner_switch_at' in res_parameters:
res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at'))
del res_parameters['refiner_switch_at']
return res_parameters
def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
"""
is hit nsfw
"""
if image.nsfw_detection_result is None:
return False
if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
return True
return False

1556
api/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -153,7 +153,7 @@ langfuse = "^2.36.1"
langsmith = "^0.1.77"
mailchimp-transactional = "~1.0.50"
markdown = "~3.5.1"
novita-client = "^0.5.6"
novita-client = "^0.5.7"
numpy = "~1.26.4"
openai = "~1.29.0"
openpyxl = "~3.1.5"