From 15f341b655fadf477126aa4cfcc9302b034e72d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Thu, 14 Nov 2024 16:37:15 +0800 Subject: [PATCH] feat: add the audio tool (#10695) --- .../provider/builtin/audio/_assets/icon.svg | 3 + .../tools/provider/builtin/audio/audio.py | 6 ++ .../tools/provider/builtin/audio/audio.yaml | 11 +++ .../tools/provider/builtin/audio/tools/asr.py | 70 +++++++++++++++ .../provider/builtin/audio/tools/asr.yaml | 22 +++++ .../tools/provider/builtin/audio/tools/tts.py | 90 +++++++++++++++++++ .../provider/builtin/audio/tools/tts.yaml | 22 +++++ 7 files changed, 224 insertions(+) create mode 100644 api/core/tools/provider/builtin/audio/_assets/icon.svg create mode 100644 api/core/tools/provider/builtin/audio/audio.py create mode 100644 api/core/tools/provider/builtin/audio/audio.yaml create mode 100644 api/core/tools/provider/builtin/audio/tools/asr.py create mode 100644 api/core/tools/provider/builtin/audio/tools/asr.yaml create mode 100644 api/core/tools/provider/builtin/audio/tools/tts.py create mode 100644 api/core/tools/provider/builtin/audio/tools/tts.yaml diff --git a/api/core/tools/provider/builtin/audio/_assets/icon.svg b/api/core/tools/provider/builtin/audio/_assets/icon.svg new file mode 100644 index 0000000000..08cc4ede66 --- /dev/null +++ b/api/core/tools/provider/builtin/audio/_assets/icon.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/audio/audio.py b/api/core/tools/provider/builtin/audio/audio.py new file mode 100644 index 0000000000..1f15386f78 --- /dev/null +++ b/api/core/tools/provider/builtin/audio/audio.py @@ -0,0 +1,6 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class AudioToolProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/tools/provider/builtin/audio/audio.yaml b/api/core/tools/provider/builtin/audio/audio.yaml new file mode 100644 index 0000000000..07db268dac --- /dev/null +++ b/api/core/tools/provider/builtin/audio/audio.yaml @@ -0,0 +1,11 @@ +identity: + author: hjlarry + name: audio + label: + en_US: Audio + description: + en_US: A tool for tts and asr. + zh_Hans: 一个用于文本转语音和语音转文本的工具。 + icon: icon.svg + tags: + - utilities diff --git a/api/core/tools/provider/builtin/audio/tools/asr.py b/api/core/tools/provider/builtin/audio/tools/asr.py new file mode 100644 index 0000000000..d1a3a81881 --- /dev/null +++ b/api/core/tools/provider/builtin/audio/tools/asr.py @@ -0,0 +1,70 @@ +import io +from typing import Any + +from core.file.enums import FileType +from core.file.file_manager import download +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption +from core.tools.tool.builtin_tool import BuiltinTool +from services.model_provider_service import ModelProviderService + + +class ASRTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: + file = tool_parameters.get("audio_file") + if file.type != FileType.AUDIO: + return [self.create_text_message("not a valid audio file")] + audio_binary = io.BytesIO(download(file)) + audio_binary.name = "temp.mp3" + provider, model = tool_parameters.get("model").split("#") + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.runtime.tenant_id, + provider=provider, + model_type=ModelType.SPEECH2TEXT, + model=model, + ) + text = model_instance.invoke_speech2text( + file=audio_binary, + user=user_id, + ) + return [self.create_text_message(text)] + + def get_available_models(self) -> list[tuple[str, str]]: + model_provider_service = ModelProviderService() + models = model_provider_service.get_models_by_model_type( + tenant_id=self.runtime.tenant_id, model_type="speech2text" + ) + items = [] + for provider_model in models: + provider = provider_model.provider + for model in provider_model.models: + items.append((provider, model.model)) + return items + + def get_runtime_parameters(self) -> list[ToolParameter]: + parameters = [] + + options = [] + for provider, model in self.get_available_models(): + option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})")) + options.append(option) + + parameters.append( + ToolParameter( + name="model", + label=I18nObject(en_US="Model", zh_Hans="Model"), + human_description=I18nObject( + en_US="All available ASR models", + zh_Hans="所有可用的 ASR 模型", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + required=True, + default=options[0].value, + options=options, + ) + ) + return parameters diff --git a/api/core/tools/provider/builtin/audio/tools/asr.yaml b/api/core/tools/provider/builtin/audio/tools/asr.yaml new file mode 100644 index 0000000000..b2c82f8086 --- /dev/null +++ b/api/core/tools/provider/builtin/audio/tools/asr.yaml @@ -0,0 +1,22 @@ +identity: + name: asr + author: hjlarry + label: + en_US: Speech To Text +description: + human: + en_US: Convert audio file to text. + zh_Hans: 将音频文件转换为文本。 + llm: Convert audio file to text. +parameters: + - name: audio_file + type: file + required: true + label: + en_US: Audio File + zh_Hans: 音频文件 + human_description: + en_US: The audio file to be converted. + zh_Hans: 要转换的音频文件。 + llm_description: The audio file to be converted. + form: llm diff --git a/api/core/tools/provider/builtin/audio/tools/tts.py b/api/core/tools/provider/builtin/audio/tools/tts.py new file mode 100644 index 0000000000..fb7213dee7 --- /dev/null +++ b/api/core/tools/provider/builtin/audio/tools/tts.py @@ -0,0 +1,90 @@ +import io +from typing import Any + +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption +from core.tools.tool.builtin_tool import BuiltinTool +from services.model_provider_service import ModelProviderService + + +class TTSTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: + provider, model = tool_parameters.get("model").split("#") + voice = tool_parameters.get(f"voice#{provider}#{model}") + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.runtime.tenant_id, + provider=provider, + model_type=ModelType.TTS, + model=model, + ) + tts = model_instance.invoke_tts( + content_text=tool_parameters.get("text"), + user=user_id, + tenant_id=self.runtime.tenant_id, + voice=voice, + ) + buffer = io.BytesIO() + for chunk in tts: + buffer.write(chunk) + + wav_bytes = buffer.getvalue() + return [ + self.create_text_message("Audio generated successfully"), + self.create_blob_message( + blob=wav_bytes, + meta={"mime_type": "audio/x-wav"}, + save_as=self.VariableKey.AUDIO, + ), + ] + + def get_available_models(self) -> list[tuple[str, str, list[Any]]]: + model_provider_service = ModelProviderService() + models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts") + items = [] + for provider_model in models: + provider = provider_model.provider + for model in provider_model.models: + voices = model.model_properties.get(ModelPropertyKey.VOICES, []) + items.append((provider, model.model, voices)) + return items + + def get_runtime_parameters(self) -> list[ToolParameter]: + parameters = [] + + options = [] + for provider, model, voices in self.get_available_models(): + option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})")) + options.append(option) + parameters.append( + ToolParameter( + name=f"voice#{provider}#{model}", + label=I18nObject(en_US=f"Voice of {model}({provider})"), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + options=[ + ToolParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name"))) + for voice in voices + ], + ) + ) + + parameters.insert( + 0, + ToolParameter( + name="model", + label=I18nObject(en_US="Model", zh_Hans="Model"), + human_description=I18nObject( + en_US="All available TTS models", + zh_Hans="所有可用的 TTS 模型", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + required=True, + default=options[0].value, + options=options, + ), + ) + return parameters diff --git a/api/core/tools/provider/builtin/audio/tools/tts.yaml b/api/core/tools/provider/builtin/audio/tools/tts.yaml new file mode 100644 index 0000000000..36f42bd689 --- /dev/null +++ b/api/core/tools/provider/builtin/audio/tools/tts.yaml @@ -0,0 +1,22 @@ +identity: + name: tts + author: hjlarry + label: + en_US: Text To Speech +description: + human: + en_US: Convert text to audio file. + zh_Hans: 将文本转换为音频文件。 + llm: Convert text to audio file. +parameters: + - name: text + type: string + required: true + label: + en_US: Text + zh_Hans: 文本 + human_description: + en_US: The text to be converted. + zh_Hans: 要转换的文本。 + llm_description: The text to be converted. + form: llm