diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 7fc0da73fb..9779fa71a0 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -583,3 +583,113 @@ SPEECH2TEXT_BASE_MODELS = [ ) ) ] +TTS_BASE_MODELS = [ + AzureBaseModel( + base_model_name='tts-1', + entity=AIModelEntity( + model='fake-deployment-name', + label=I18nObject( + en_US='fake-deployment-name-label' + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TTS, + model_properties={ + ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.VOICES: [ + { + 'mode': 'alloy', + 'name': 'Alloy', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'echo', + 'name': 'Echo', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'fable', + 'name': 'Fable', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'onyx', + 'name': 'Onyx', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'nova', + 'name': 'Nova', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'shimmer', + 'name': 'Shimmer', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + ], + ModelPropertyKey.WORD_LIMIT: 120, + ModelPropertyKey.AUDOI_TYPE: 'mp3', + ModelPropertyKey.MAX_WORKERS: 5 + }, + pricing=PriceConfig( + input=0.015, + unit=0.001, + currency='USD', + ) + ) + ), + AzureBaseModel( + base_model_name='tts-1-hd', + entity=AIModelEntity( + model='fake-deployment-name', + label=I18nObject( + en_US='fake-deployment-name-label' + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TTS, + model_properties={ + ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.VOICES: [ + { + 'mode': 'alloy', + 'name': 'Alloy', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'echo', + 'name': 'Echo', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'fable', + 'name': 'Fable', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'onyx', + 'name': 'Onyx', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'nova', + 'name': 'Nova', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'shimmer', + 'name': 'Shimmer', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + ], + ModelPropertyKey.WORD_LIMIT: 120, + ModelPropertyKey.AUDOI_TYPE: 'mp3', + ModelPropertyKey.MAX_WORKERS: 5 + }, + pricing=PriceConfig( + input=0.03, + unit=0.001, + currency='USD', + ) + ) + ) +] diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml index 6c56ccc920..58800ddee2 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml @@ -16,6 +16,7 @@ supported_model_types: - llm - text-embedding - speech2text + - tts configurate_methods: - customizable-model model_credential_schema: @@ -118,6 +119,18 @@ model_credential_schema: show_on: - variable: __model_type value: speech2text + - label: + en_US: tts-1 + value: tts-1 + show_on: + - variable: __model_type + value: tts + - label: + en_US: tts-1-hd + value: tts-1-hd + show_on: + - variable: __model_type + value: tts placeholder: zh_Hans: 在此输入您的模型版本 en_US: Enter your model version diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/__init__.py b/api/core/model_runtime/model_providers/azure_openai/tts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py new file mode 100644 index 0000000000..585b061afe --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -0,0 +1,174 @@ +import concurrent.futures +import copy +from functools import reduce +from io import BytesIO +from typing import Optional + +from flask import Response, stream_with_context +from openai import AzureOpenAI +from pydub import AudioSegment + +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.invoke import InvokeBadRequestError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.tts_model import TTSModel +from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI +from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel +from extensions.ext_storage import storage + + +class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): + """ + Model class for OpenAI Speech to text model. + """ + + def _invoke(self, model: str, tenant_id: str, credentials: dict, + content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any: + """ + _invoke text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :param streaming: output is streaming + :param user: unique user id + :return: text translated to audio file + """ + audio_type = self._get_model_audio_type(model, credentials) + if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: + voice = self._get_model_default_voice(model, credentials) + if streaming: + return Response(stream_with_context(self._tts_invoke_streaming(model=model, + credentials=credentials, + content_text=content_text, + tenant_id=tenant_id, + voice=voice)), + status=200, mimetype=f'audio/{audio_type}') + else: + return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice) + + def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: + """ + validate credentials text2speech model + + :param model: model name + :param credentials: model credentials + :param user: unique user id + :return: text translated to audio file + """ + try: + self._tts_invoke( + model=model, + credentials=credentials, + content_text='Hello Dify!', + voice=self._get_model_default_voice(model, credentials), + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response: + """ + _tts_invoke text2speech model + + :param model: model name + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :return: text translated to audio file + """ + audio_type = self._get_model_audio_type(model, credentials) + word_limit = self._get_model_word_limit(model, credentials) + max_workers = self._get_model_workers_limit(model, credentials) + try: + sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) + audio_bytes_list = list() + + # Create a thread pool and map the function to the list of sentences + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(self._process_sentence, sentence=sentence, model=model, voice=voice, + credentials=credentials) for sentence in sentences] + for future in futures: + try: + if future.result(): + audio_bytes_list.append(future.result()) + except Exception as ex: + raise InvokeBadRequestError(str(ex)) + + if len(audio_bytes_list) > 0: + audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in + audio_bytes_list if audio_bytes] + combined_segment = reduce(lambda x, y: x + y, audio_segments) + buffer: BytesIO = BytesIO() + combined_segment.export(buffer, format=audio_type) + buffer.seek(0) + return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}") + except Exception as ex: + raise InvokeBadRequestError(str(ex)) + + # Todo: To improve the streaming function + def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str, + voice: str) -> any: + """ + _tts_invoke_streaming text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :return: text translated to audio file + """ + # transform credentials to kwargs for model instance + credentials_kwargs = self._to_credential_kwargs(credentials) + if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials): + voice = self._get_model_default_voice(model, credentials) + word_limit = self._get_model_word_limit(model, credentials) + audio_type = self._get_model_audio_type(model, credentials) + tts_file_id = self._get_file_name(content_text) + file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}' + try: + client = AzureOpenAI(**credentials_kwargs) + sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) + for sentence in sentences: + response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip()) + # response.stream_to_file(file_path) + storage.save(file_path, response.read()) + except Exception as ex: + raise InvokeBadRequestError(str(ex)) + + def _process_sentence(self, sentence: str, model: str, + voice, credentials: dict): + """ + _tts_invoke openai text2speech model api + + :param model: model name + :param credentials: model credentials + :param voice: model timbre + :param sentence: text content to be translated + :return: text translated to audio file + """ + # transform credentials to kwargs for model instance + credentials_kwargs = self._to_credential_kwargs(credentials) + client = AzureOpenAI(**credentials_kwargs) + response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip()) + if isinstance(response.read(), bytes): + return response.read() + + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + return ai_model_entity.entity + + + @staticmethod + def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: + for ai_model_entity in TTS_BASE_MODELS: + if ai_model_entity.base_model_name == base_model_name: + ai_model_entity_copy = copy.deepcopy(ai_model_entity) + ai_model_entity_copy.entity.model = model + ai_model_entity_copy.entity.label.en_US = model + ai_model_entity_copy.entity.label.zh_Hans = model + return ai_model_entity_copy + + return None diff --git a/api/requirements.txt b/api/requirements.txt index 847903c4f4..7edd95a893 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -11,7 +11,7 @@ flask-cors~=4.0.0 gunicorn~=21.2.0 gevent~=23.9.1 langchain==0.0.250 -openai~=1.3.6 +openai~=1.13.3 tiktoken~=0.5.2 psycopg2-binary~=2.9.6 pycryptodome==3.19.1