From 292220c596350fde6802b4946835b063c20de0a4 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Wed, 11 Sep 2024 16:40:52 +0800 Subject: [PATCH] chore: apply pep8-naming rules for naming convention (#8261) --- api/controllers/console/app/conversation.py | 10 +- api/controllers/console/app/statistic.py | 34 +++--- .../console/app/workflow_statistic.py | 18 ++-- api/controllers/console/auth/activate.py | 4 +- api/controllers/console/init_validate.py | 4 +- api/controllers/console/setup.py | 4 +- .../app/apps/advanced_chat/app_generator.py | 6 +- .../app_generator_tts_publisher.py | 8 +- api/core/app/apps/advanced_chat/app_runner.py | 4 +- .../advanced_chat/generate_task_pipeline.py | 8 +- api/core/app/apps/agent_chat/app_generator.py | 4 +- api/core/app/apps/agent_chat/app_runner.py | 4 +- api/core/app/apps/base_app_queue_manager.py | 2 +- api/core/app/apps/chat/app_generator.py | 4 +- api/core/app/apps/chat/app_runner.py | 4 +- api/core/app/apps/completion/app_generator.py | 4 +- api/core/app/apps/completion/app_runner.py | 4 +- .../app/apps/message_based_app_generator.py | 4 +- .../apps/message_based_app_queue_manager.py | 4 +- api/core/app/apps/workflow/app_generator.py | 6 +- .../app/apps/workflow/app_queue_manager.py | 4 +- .../apps/workflow/generate_task_pipeline.py | 8 +- api/core/app/segments/segments.py | 1 + .../easy_ui_based_generate_task_pipeline.py | 8 +- .../helper/code_executor/code_executor.py | 18 ++-- api/core/indexing_runner.py | 22 ++-- .../llm_generator/output_parser/errors.py | 2 +- .../output_parser/rule_config_generator.py | 4 +- .../baichuan/llm/baichuan_turbo.py | 4 +- .../baichuan/llm/baichuan_turbo_errors.py | 2 +- .../model_providers/baichuan/llm/llm.py | 4 +- .../baichuan/text_embedding/text_embedding.py | 6 +- .../openai_api_compatible/_common.py | 2 +- .../openai_api_compatible/llm/llm.py | 4 +- .../speech2text/speech2text.py | 4 +- .../text_embedding/text_embedding.py | 4 +- .../text_embedding/text_embedding.py | 4 +- 
.../volcengine_maas/legacy/client.py | 4 +- .../volcengine_maas/legacy/errors.py | 100 +++++++++--------- .../legacy/volc_sdk/__init__.py | 4 +- .../volcengine_maas/legacy/volc_sdk/maas.py | 14 +-- .../volcengine_maas/llm/llm.py | 4 +- .../text_embedding/text_embedding.py | 4 +- .../model_providers/wenxin/wenxin_errors.py | 4 +- api/core/moderation/base.py | 4 +- api/core/moderation/input_moderation.py | 4 +- api/core/ops/entities/config_entity.py | 2 + api/core/ops/entities/trace_entity.py | 1 + .../entities/langfuse_trace_entity.py | 4 + .../entities/langsmith_trace_entity.py | 2 + .../vdb/elasticsearch/elasticsearch_vector.py | 1 + .../datasource/vdb/milvus/milvus_vector.py | 1 + .../vdb/opensearch/opensearch_vector.py | 1 + .../rag/datasource/vdb/oracle/oraclevector.py | 1 + .../datasource/vdb/pgvecto_rs/pgvecto_rs.py | 1 + .../rag/datasource/vdb/pgvector/pgvector.py | 1 + .../rag/datasource/vdb/relyt/relyt_vector.py | 1 + .../datasource/vdb/tidb_vector/tidb_vector.py | 1 + .../vdb/weaviate/weaviate_vector.py | 1 + .../builtin/azuredalle/tools/dalle3.py | 2 +- .../provider/builtin/dalle/tools/dalle2.py | 2 +- .../provider/builtin/dalle/tools/dalle3.py | 2 +- .../novitaai/tools/novitaai_createtile.py | 2 +- .../novitaai/tools/novitaai_txt2img.py | 2 +- .../builtin/qrcode/tools/qrcode_generator.py | 2 +- .../builtin/siliconflow/tools/flux.py | 2 +- .../siliconflow/tools/stable_diffusion.py | 2 +- .../spark/tools/spark_img_generation.py | 14 +-- .../builtin/stability/tools/text2image.py | 2 +- .../stablediffusion/tools/stable_diffusion.py | 4 +- .../builtin/tianditu/tools/staticmap.py | 2 +- .../builtin/vectorizer/tools/vectorizer.py | 2 +- api/core/tools/tool/tool.py | 6 +- .../workflow/graph_engine/graph_engine.py | 4 +- api/core/workflow/nodes/code/code_node.py | 4 +- .../template_transform_node.py | 4 +- api/core/workflow/workflow_entry.py | 4 +- .../event_handlers/create_document_index.py | 4 +- api/extensions/storage/google_storage.py | 6 +- 
api/libs/gmpy2_pkcs10aep_cipher.py | 4 +- api/libs/helper.py | 6 +- api/libs/json_in_md_parser.py | 6 +- api/pyproject.toml | 11 +- api/services/account_service.py | 10 +- api/services/errors/account.py | 4 +- api/tasks/document_indexing_sync_task.py | 4 +- api/tasks/document_indexing_task.py | 4 +- api/tasks/document_indexing_update_task.py | 4 +- api/tasks/duplicate_document_indexing_task.py | 4 +- api/tasks/recover_document_indexing_task.py | 4 +- .../model_runtime/__mock/huggingface_tei.py | 1 + .../integration_tests/tools/__mock/http.py | 1 + .../vdb/__mock/tcvectordb.py | 4 +- .../workflow/nodes/__mock/http.py | 1 + .../nodes/code_executor/test_code_executor.py | 4 +- 95 files changed, 287 insertions(+), 258 deletions(-) diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 46c0b22993..df7bd352af 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -20,7 +20,7 @@ from fields.conversation_fields import ( conversation_pagination_fields, conversation_with_summary_pagination_fields, ) -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation @@ -36,8 +36,8 @@ class CompletionConversationApi(Resource): raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("keyword", type=str, location="args") - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument( "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" ) @@ -143,8 +143,8 @@ class 
ChatConversationApi(Resource): raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("keyword", type=str, location="args") - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument( "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" ) diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 81826a20d0..4806b02b55 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required from models.model import AppMode @@ -25,8 +25,8 @@ class DailyMessageStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -79,8 +79,8 @@ class DailyConversationStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", 
type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -133,8 +133,8 @@ class DailyTerminalsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -187,8 +187,8 @@ class DailyTokenCostStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -245,8 +245,8 @@ class AverageSessionInteractionStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, @@ -307,8 +307,8 @@ class UserSatisfactionRateStatistic(Resource): account = current_user parser = 
reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -369,8 +369,8 @@ class AverageResponseTimeStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -425,8 +425,8 @@ class TokensPerSecondStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index db2f683589..942271a634 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from 
extensions.ext_database import db -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required from models.model import AppMode from models.workflow import WorkflowRunTriggeredFrom @@ -26,8 +26,8 @@ class WorkflowDailyRunsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -86,8 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -146,8 +146,8 @@ class WorkflowDailyTokenCostStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -213,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - 
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 8ba6b53e7e..f3198dfc1d 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -8,7 +8,7 @@ from constants.languages import supported_language from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db -from libs.helper import email, str_len, timezone +from libs.helper import StrLen, email, timezone from libs.password import hash_password, valid_password from models.account import AccountStatus from services.account_service import RegisterService @@ -37,7 +37,7 @@ class ActivateApi(Resource): parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") parser.add_argument("email", type=email, required=False, nullable=True, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json") parser.add_argument( "interface_language", type=supported_language, required=True, nullable=False, location="json" diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 7d3ae677ee..ae759bb752 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -4,7 +4,7 @@ from flask import session from flask_restful import 
Resource, reqparse from configs import dify_config -from libs.helper import str_len +from libs.helper import StrLen from models.model import DifySetup from services.account_service import TenantService @@ -28,7 +28,7 @@ class InitValidateAPI(Resource): raise AlreadySetupError() parser = reqparse.RequestParser() - parser.add_argument("password", type=str_len(30), required=True, location="json") + parser.add_argument("password", type=StrLen(30), required=True, location="json") input_password = parser.parse_args()["password"] if input_password != os.environ.get("INIT_PASSWORD"): diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 827695e00f..46b4ef5d87 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -4,7 +4,7 @@ from flask import request from flask_restful import Resource, reqparse from configs import dify_config -from libs.helper import email, get_remote_ip, str_len +from libs.helper import StrLen, email, get_remote_ip from libs.password import valid_password from models.model import DifySetup from services.account_service import RegisterService, TenantService @@ -40,7 +40,7 @@ class SetupApi(Resource): parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") - parser.add_argument("name", type=str_len(30), required=True, location="json") + parser.add_argument("name", type=StrLen(30), required=True, location="json") parser.add_argument("password", type=valid_password, required=True, location="json") args = parser.parse_args() diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 1277dcebc5..88e1256ed5 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -15,7 +15,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.app_runner import 
AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom @@ -293,7 +293,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): ) runner.run() - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( @@ -349,7 +349,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index d9fc599542..18b115dfe4 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -21,7 +21,7 @@ class AudioTrunk: self.status = status -def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): if not text_content or text_content.isspace(): return return model_instance.invoke_tts( @@ -81,7 +81,7 @@ class AppGeneratorTTSPublisher: if message is None: if self.msg_text and len(self.msg_text.strip()) > 0: 
futures_result = self.executor.submit( - _invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice + _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice ) future_queue.put(futures_result) break @@ -97,7 +97,7 @@ class AppGeneratorTTSPublisher: self.MAX_SENTENCE += 1 text_content = "".join(sentence_arr) futures_result = self.executor.submit( - _invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice + _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice ) future_queue.put(futures_result) if text_tmp: @@ -110,7 +110,7 @@ class AppGeneratorTTSPublisher: break future_queue.put(None) - def checkAndGetAudio(self) -> AudioTrunk | None: + def check_and_get_audio(self) -> AudioTrunk | None: try: if self._last_audio_event and self._last_audio_event.status == "finish": if self.executor: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 90f547b0f2..c4cdba6441 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -19,7 +19,7 @@ from core.app.entities.queue_entities import ( QueueStopEvent, QueueTextChunkEvent, ) -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool @@ -217,7 +217,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): query=query, message_id=message_id, ) - except ModerationException as e: + except ModerationError as e: self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION) return True diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 8f65a670c3..94206a1b1c 100644 --- 
a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -179,10 +179,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc stream_response=stream_response, ) - def _listenAudioMsg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -204,7 +204,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -217,7 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc try: if not tts_publisher: break - audio_trunk = tts_publisher.checkAndGetAudio() + audio_trunk = tts_publisher.check_and_get_audio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 7ba6bbab94..abf8a332ab 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.app.apps.agent_chat.generate_response_converter import 
AgentChatAppGenerateResponseConverter -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom @@ -205,7 +205,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 6b676b0353..45b1bf0093 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -15,7 +15,7 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.tools.entities.tool_entities import ToolRuntimeVariablePool from extensions.ext_database import db from models.model import App, Conversation, Message, MessageAgentThought @@ -103,7 +103,7 @@ class AgentChatAppRunner(AppRunner): query=query, message_id=message.id, ) - except ModerationException as e: + except ModerationError as e: self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index df972756d5..f3c3199354 
100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -171,5 +171,5 @@ class AppQueueManager: ) -class GenerateTaskStoppedException(Exception): +class GenerateTaskStoppedError(Exception): pass diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 15c7140308..032556ec4c 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -10,7 +10,7 @@ from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter @@ -205,7 +205,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index bd90586825..425f1ab7ef 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -11,7 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.moderation.base import ModerationException +from core.moderation.base import 
ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db from models.model import App, Conversation, Message @@ -98,7 +98,7 @@ class ChatAppRunner(AppRunner): query=query, message_id=message.id, ) - except ModerationException as e: + except ModerationError as e: self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index d7301224e8..7fce296f2b 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -10,7 +10,7 @@ from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter @@ -185,7 +185,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): queue_manager=queue_manager, message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index da49c8701f..908d74ff53 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -9,7 +9,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.callback_handler.index_tool_callback_handler 
import DatasetIndexToolCallbackHandler from core.model_manager import ModelInstance -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db from models.model import App, Message @@ -79,7 +79,7 @@ class CompletionAppRunner(AppRunner): query=query, message_id=message.id, ) - except ModerationException as e: + except ModerationError as e: self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index a91d48d246..f629c5c8b7 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -8,7 +8,7 @@ from sqlalchemy import and_ from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, @@ -77,7 +77,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 7f259db6eb..363c3c82bb 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -1,4 +1,4 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, 
GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, @@ -53,4 +53,4 @@ class MessageBasedAppQueueManager(AppQueueManager): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index c685008577..57a77591a0 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -12,7 +12,7 @@ from pydantic import ValidationError import contexts from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner @@ -253,7 +253,7 @@ class WorkflowAppGenerator(BaseAppGenerator): ) runner.run() - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( @@ -302,7 +302,7 @@ class WorkflowAppGenerator(BaseAppGenerator): return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e diff --git a/api/core/app/apps/workflow/app_queue_manager.py 
b/api/core/app/apps/workflow/app_queue_manager.py index c9f501cd5e..76371f800b 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -1,4 +1,4 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, @@ -39,4 +39,4 @@ class WorkflowAppQueueManager(AppQueueManager): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 215d02bddd..93edf8e0e8 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -162,10 +162,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) - def _listenAudioMsg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -187,7 +187,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(tts_publisher, 
task_id=task_id) if audio_response: yield audio_response else: @@ -199,7 +199,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa try: if not tts_publisher: break - audio_trunk = tts_publisher.checkAndGetAudio() + audio_trunk = tts_publisher.check_and_get_audio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index b71924b2d3..b26b3c8291 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -15,6 +15,7 @@ class Segment(BaseModel): value: Any @field_validator("value_type") + @classmethod def validate_value_type(cls, value): """ This validator checks if the provided value is equal to the default value of the 'value_type' field. diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 61e920845c..659503301e 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -201,10 +201,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan stream_response=stream_response, ) - def _listenAudioMsg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher, task_id: str): if publisher is None: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": # audio_str = audio_msg.audio.decode('utf-8', errors='ignore') return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) @@ -225,7 +225,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None)) for response in 
self._process_stream_response(publisher=publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id) + audio_response = self._listen_audio_msg(publisher, task_id) if audio_response: yield audio_response else: @@ -237,7 +237,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: if publisher is None: break - audio = publisher.checkAndGetAudio() + audio = publisher.check_and_get_audio() if audio is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 4a80a3ffe9..7ee6e63817 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -16,7 +16,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer logger = logging.getLogger(__name__) -class CodeExecutionException(Exception): +class CodeExecutionError(Exception): pass @@ -86,15 +86,15 @@ class CodeExecutor: ), ) if response.status_code == 503: - raise CodeExecutionException("Code execution service is unavailable") + raise CodeExecutionError("Code execution service is unavailable") elif response.status_code != 200: raise Exception( f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running" ) - except CodeExecutionException as e: + except CodeExecutionError as e: raise e except Exception as e: - raise CodeExecutionException( + raise CodeExecutionError( "Failed to execute code, which is likely a network issue," " please check if the sandbox service is running." 
f" ( Error: {str(e)} )" @@ -103,15 +103,15 @@ class CodeExecutor: try: response = response.json() except: - raise CodeExecutionException("Failed to parse response") + raise CodeExecutionError("Failed to parse response") if (code := response.get("code")) != 0: - raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}") + raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}") response = CodeExecutionResponse(**response) if response.data.error: - raise CodeExecutionException(response.data.error) + raise CodeExecutionError(response.data.error) return response.data.stdout or "" @@ -126,13 +126,13 @@ class CodeExecutor: """ template_transformer = cls.code_template_transformers.get(language) if not template_transformer: - raise CodeExecutionException(f"Unsupported language {language}") + raise CodeExecutionError(f"Unsupported language {language}") runner, preload = template_transformer.transform_caller(code, inputs) try: response = cls.execute_code(language, preload, runner) - except CodeExecutionException as e: + except CodeExecutionError as e: raise e return template_transformer.transform_response(response) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index b6968e46cd..eeb1dbfda0 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -78,8 +78,8 @@ class IndexingRunner: dataset_document=dataset_document, documents=documents, ) - except DocumentIsPausedException: - raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) @@ -134,8 +134,8 @@ class IndexingRunner: self._load( index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, 
documents=documents ) - except DocumentIsPausedException: - raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) @@ -192,8 +192,8 @@ class IndexingRunner: self._load( index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) - except DocumentIsPausedException: - raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) @@ -756,7 +756,7 @@ class IndexingRunner: indexing_cache_key = "document_{}_is_paused".format(document_id) result = redis_client.get(indexing_cache_key) if result: - raise DocumentIsPausedException() + raise DocumentIsPausedError() @staticmethod def _update_document_index_status( @@ -767,10 +767,10 @@ class IndexingRunner: """ count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count() if count > 0: - raise DocumentIsPausedException() + raise DocumentIsPausedError() document = DatasetDocument.query.filter_by(id=document_id).first() if not document: - raise DocumentIsDeletedPausedException() + raise DocumentIsDeletedPausedError() update_params = {DatasetDocument.indexing_status: after_indexing_status} @@ -875,9 +875,9 @@ class IndexingRunner: pass -class DocumentIsPausedException(Exception): +class DocumentIsPausedError(Exception): pass -class DocumentIsDeletedPausedException(Exception): +class DocumentIsDeletedPausedError(Exception): pass diff --git a/api/core/llm_generator/output_parser/errors.py 
b/api/core/llm_generator/output_parser/errors.py index 6a60f8de80..1e743f1757 100644 --- a/api/core/llm_generator/output_parser/errors.py +++ b/api/core/llm_generator/output_parser/errors.py @@ -1,2 +1,2 @@ -class OutputParserException(Exception): +class OutputParserError(Exception): pass diff --git a/api/core/llm_generator/output_parser/rule_config_generator.py b/api/core/llm_generator/output_parser/rule_config_generator.py index b6932698cb..0c7683b16d 100644 --- a/api/core/llm_generator/output_parser/rule_config_generator.py +++ b/api/core/llm_generator/output_parser/rule_config_generator.py @@ -1,6 +1,6 @@ from typing import Any -from core.llm_generator.output_parser.errors import OutputParserException +from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.prompts import ( RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, @@ -29,4 +29,4 @@ class RuleConfigGeneratorOutputParser: raise ValueError("Expected 'opening_statement' to be a str.") return parsed except Exception as e: - raise OutputParserException(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}") + raise OutputParserError(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}") diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index 6e181ac5f8..39f867118b 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -7,7 +7,7 @@ from requests import post from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -124,7 +124,7 @@ 
class BaichuanModel: if err == "invalid_api_key": raise InvalidAPIKeyError(msg) elif err == "insufficient_quota": - raise InsufficientAccountBalance(msg) + raise InsufficientAccountBalanceError(msg) elif err == "invalid_authentication": raise InvalidAuthenticationError(msg) elif err == "invalid_request_error": diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py index 4e56e58d7e..309b5cf413 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py @@ -10,7 +10,7 @@ class RateLimitReachedError(Exception): pass -class InsufficientAccountBalance(Exception): +class InsufficientAccountBalanceError(Exception): pass diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py index 3291fe2b2e..91a14bf100 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py @@ -29,7 +29,7 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import B from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -289,7 +289,7 @@ class BaichuanLanguageModel(LargeLanguageModel): InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InvalidAPIKeyError, ], InvokeBadRequestError: [BadRequestError, KeyError], diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py 
b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index b7276fabb5..779dfbb608 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -19,7 +19,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -109,7 +109,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): if err == "invalid_api_key": raise InvalidAPIKeyError(msg) elif err == "insufficient_quota": - raise InsufficientAccountBalance(msg) + raise InsufficientAccountBalanceError(msg) elif err == "invalid_authentication": raise InvalidAuthenticationError(msg) elif err and "rate" in err: @@ -166,7 +166,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InvalidAPIKeyError, ], InvokeBadRequestError: [BadRequestError, KeyError], diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py index 257dffa30d..1234e44f80 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py @@ -10,7 +10,7 @@ from core.model_runtime.errors.invoke import ( ) -class _CommonOAI_API_Compat: +class _CommonOaiApiCompat: @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ diff --git 
a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 75929af590..24317b488c 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -35,13 +35,13 @@ from core.model_runtime.entities.model_entities import ( from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat from core.model_runtime.utils import helper logger = logging.getLogger(__name__) -class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): +class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): """ Model class for OpenAI large language model. 
""" diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py index 2e8b4ddd72..405096578c 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py @@ -6,10 +6,10 @@ import requests from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel): +class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel): """ Model class for OpenAI Compatible Speech to text model. 
""" diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index ab358cf70a..e83cfdf873 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): +class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): """ Model class for an OpenAI API-compatible text embedding model. 
""" diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index d0522233e3..b62a2d2aaf 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): +class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): """ Model class for an OpenAI API-compatible text embedding model. 
""" diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py index 025b1ed6d2..266f1216f8 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py @@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error -from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasError, MaasService class MaaSClient(MaasService): @@ -106,7 +106,7 @@ class MaaSClient(MaasService): def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: try: resp = fn() - except MaasException as e: + except MaasError as e: raise wrap_error(e) return resp diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py index 8b9c346265..91dbe21a61 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py @@ -1,144 +1,144 @@ -from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasError -class ClientSDKRequestError(MaasException): +class ClientSDKRequestError(MaasError): pass -class SignatureDoesNotMatch(MaasException): +class SignatureDoesNotMatchError(MaasError): pass -class RequestTimeout(MaasException): +class RequestTimeoutError(MaasError): pass -class ServiceConnectionTimeout(MaasException): +class ServiceConnectionTimeoutError(MaasError): pass -class 
MissingAuthenticationHeader(MaasException): +class MissingAuthenticationHeaderError(MaasError): pass -class AuthenticationHeaderIsInvalid(MaasException): +class AuthenticationHeaderIsInvalidError(MaasError): pass -class InternalServiceError(MaasException): +class InternalServiceError(MaasError): pass -class MissingParameter(MaasException): +class MissingParameterError(MaasError): pass -class InvalidParameter(MaasException): +class InvalidParameterError(MaasError): pass -class AuthenticationExpire(MaasException): +class AuthenticationExpireError(MaasError): pass -class EndpointIsInvalid(MaasException): +class EndpointIsInvalidError(MaasError): pass -class EndpointIsNotEnable(MaasException): +class EndpointIsNotEnableError(MaasError): pass -class ModelNotSupportStreamMode(MaasException): +class ModelNotSupportStreamModeError(MaasError): pass -class ReqTextExistRisk(MaasException): +class ReqTextExistRiskError(MaasError): pass -class RespTextExistRisk(MaasException): +class RespTextExistRiskError(MaasError): pass -class EndpointRateLimitExceeded(MaasException): +class EndpointRateLimitExceededError(MaasError): pass -class ServiceConnectionRefused(MaasException): +class ServiceConnectionRefusedError(MaasError): pass -class ServiceConnectionClosed(MaasException): +class ServiceConnectionClosedError(MaasError): pass -class UnauthorizedUserForEndpoint(MaasException): +class UnauthorizedUserForEndpointError(MaasError): pass -class InvalidEndpointWithNoURL(MaasException): +class InvalidEndpointWithNoURLError(MaasError): pass -class EndpointAccountRpmRateLimitExceeded(MaasException): +class EndpointAccountRpmRateLimitExceededError(MaasError): pass -class EndpointAccountTpmRateLimitExceeded(MaasException): +class EndpointAccountTpmRateLimitExceededError(MaasError): pass -class ServiceResourceWaitQueueFull(MaasException): +class ServiceResourceWaitQueueFullError(MaasError): pass -class EndpointIsPending(MaasException): +class EndpointIsPendingError(MaasError): pass -class 
ServiceNotOpen(MaasException): +class ServiceNotOpenError(MaasError): pass AuthErrors = { - "SignatureDoesNotMatch": SignatureDoesNotMatch, - "MissingAuthenticationHeader": MissingAuthenticationHeader, - "AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalid, - "AuthenticationExpire": AuthenticationExpire, - "UnauthorizedUserForEndpoint": UnauthorizedUserForEndpoint, + "SignatureDoesNotMatch": SignatureDoesNotMatchError, + "MissingAuthenticationHeader": MissingAuthenticationHeaderError, + "AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalidError, + "AuthenticationExpire": AuthenticationExpireError, + "UnauthorizedUserForEndpoint": UnauthorizedUserForEndpointError, } BadRequestErrors = { - "MissingParameter": MissingParameter, - "InvalidParameter": InvalidParameter, - "EndpointIsInvalid": EndpointIsInvalid, - "EndpointIsNotEnable": EndpointIsNotEnable, - "ModelNotSupportStreamMode": ModelNotSupportStreamMode, - "ReqTextExistRisk": ReqTextExistRisk, - "RespTextExistRisk": RespTextExistRisk, - "InvalidEndpointWithNoURL": InvalidEndpointWithNoURL, - "ServiceNotOpen": ServiceNotOpen, + "MissingParameter": MissingParameterError, + "InvalidParameter": InvalidParameterError, + "EndpointIsInvalid": EndpointIsInvalidError, + "EndpointIsNotEnable": EndpointIsNotEnableError, + "ModelNotSupportStreamMode": ModelNotSupportStreamModeError, + "ReqTextExistRisk": ReqTextExistRiskError, + "RespTextExistRisk": RespTextExistRiskError, + "InvalidEndpointWithNoURL": InvalidEndpointWithNoURLError, + "ServiceNotOpen": ServiceNotOpenError, } RateLimitErrors = { - "EndpointRateLimitExceeded": EndpointRateLimitExceeded, - "EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceeded, - "EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceeded, + "EndpointRateLimitExceeded": EndpointRateLimitExceededError, + "EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceededError, + "EndpointAccountTpmRateLimitExceeded": 
EndpointAccountTpmRateLimitExceededError, } ServerUnavailableErrors = { "InternalServiceError": InternalServiceError, - "EndpointIsPending": EndpointIsPending, - "ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFull, + "EndpointIsPending": EndpointIsPendingError, + "ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFullError, } ConnectionErrors = { "ClientSDKRequestError": ClientSDKRequestError, - "RequestTimeout": RequestTimeout, - "ServiceConnectionTimeout": ServiceConnectionTimeout, - "ServiceConnectionRefused": ServiceConnectionRefused, - "ServiceConnectionClosed": ServiceConnectionClosed, + "RequestTimeout": RequestTimeoutError, + "ServiceConnectionTimeout": ServiceConnectionTimeoutError, + "ServiceConnectionRefused": ServiceConnectionRefusedError, + "ServiceConnectionClosed": ServiceConnectionClosedError, } ErrorCodeMap = { @@ -150,7 +150,7 @@ ErrorCodeMap = { } -def wrap_error(e: MaasException) -> Exception: +def wrap_error(e: MaasError) -> Exception: if ErrorCodeMap.get(e.code): return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) return e diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py index 53f320736b..8b3eb157be 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py @@ -1,4 +1,4 @@ from .common import ChatRole -from .maas import MaasException, MaasService +from .maas import MaasError, MaasService -__all__ = ["MaasService", "ChatRole", "MaasException"] +__all__ = ["MaasService", "ChatRole", "MaasError"] diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py index 01f15aec24..29c5c3c2d2 100644 --- 
a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py @@ -63,7 +63,7 @@ class MaasService(Service): raise if res.error is not None and res.error.code_n != 0: - raise MaasException( + raise MaasError( res.error.code_n, res.error.code, res.error.message, @@ -72,7 +72,7 @@ class MaasService(Service): yield res return iter_fn() - except MaasException: + except MaasError: raise except Exception as e: raise new_client_sdk_request_error(str(e)) @@ -94,7 +94,7 @@ class MaasService(Service): resp["req_id"] = req_id return resp - except MaasException as e: + except MaasError as e: raise e except Exception as e: raise new_client_sdk_request_error(str(e), req_id) @@ -147,14 +147,14 @@ class MaasService(Service): raise new_client_sdk_request_error(raw, req_id) if resp.error: - raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, req_id) + raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, req_id) else: raise new_client_sdk_request_error(resp, req_id) return res -class MaasException(Exception): +class MaasError(Exception): def __init__(self, code_n, code, message, req_id): self.code_n = code_n self.code = code @@ -172,7 +172,7 @@ class MaasException(Exception): def new_client_sdk_request_error(raw, req_id=""): - return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) + return MaasError(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) class BinaryResponseContent: @@ -192,7 +192,7 @@ class BinaryResponseContent: if len(error_bytes) > 0: resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id) - raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, self.request_id) + raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, self.request_id) def iter_bytes(self) -> Iterator[bytes]: yield 
from self.response diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index 98409ab872..c25851fc45 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -35,7 +35,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, - MaasException, + MaasError, RateLimitErrors, ServerUnavailableErrors, ) @@ -85,7 +85,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): }, [UserPromptMessage(content="ping\nAnswer: ")], ) - except MaasException as e: + except MaasError as e: raise CredentialsValidateFailedError(e.message) @staticmethod diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index 3cdcd2740c..9cba2cb879 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -28,7 +28,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, - MaasException, + MaasError, RateLimitErrors, ServerUnavailableErrors, ) @@ -111,7 +111,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): def _validate_credentials_v2(self, model: str, credentials: dict) -> None: try: self._invoke(model=model, credentials=credentials, texts=["ping"]) - except MaasException as e: + except MaasError as e: raise CredentialsValidateFailedError(e.message) def _validate_credentials_v3(self, model: str, credentials: dict) -> None: diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py index 
f2e2248680..bd074e0477 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py +++ b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py @@ -23,7 +23,7 @@ def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]: InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InvalidAPIKeyError, ], InvokeBadRequestError: [BadRequestError, KeyError], @@ -42,7 +42,7 @@ class RateLimitReachedError(Exception): pass -class InsufficientAccountBalance(Exception): +class InsufficientAccountBalanceError(Exception): pass diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 4b91f20184..60898d5547 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -76,7 +76,7 @@ class Moderation(Extensible, ABC): raise NotImplementedError @classmethod - def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None: + def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None: # inputs_config inputs_config = config.get("inputs_config") if not isinstance(inputs_config, dict): @@ -111,5 +111,5 @@ class Moderation(Extensible, ABC): raise ValueError("outputs_config.preset_response must be less than 100 characters") -class ModerationException(Exception): +class ModerationError(Exception): pass diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 336c16eecf..46d3963bd0 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -2,7 +2,7 @@ import logging from typing import Optional from core.app.app_config.entities import AppConfig -from core.moderation.base import ModerationAction, ModerationException +from core.moderation.base import ModerationAction, ModerationError from core.moderation.factory import ModerationFactory from 
core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask @@ -61,7 +61,7 @@ class InputModeration: return False, inputs, query if moderation_result.action == ModerationAction.DIRECT_OUTPUT: - raise ModerationException(moderation_result.preset_response) + raise ModerationError(moderation_result.preset_response) elif moderation_result.action == ModerationAction.OVERRIDDEN: inputs = moderation_result.inputs query = moderation_result.query diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 0ab2139a88..5c79867571 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -26,6 +26,7 @@ class LangfuseConfig(BaseTracingConfig): host: str = "https://api.langfuse.com" @field_validator("host") + @classmethod def set_value(cls, v, info: ValidationInfo): if v is None or v == "": v = "https://api.langfuse.com" @@ -45,6 +46,7 @@ class LangSmithConfig(BaseTracingConfig): endpoint: str = "https://api.smith.langchain.com" @field_validator("endpoint") + @classmethod def set_value(cls, v, info: ValidationInfo): if v is None or v == "": v = "https://api.smith.langchain.com" diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index a3ce27d5d4..f27a0af6e0 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -15,6 +15,7 @@ class BaseTraceInfo(BaseModel): metadata: dict[str, Any] @field_validator("inputs", "outputs") + @classmethod def ensure_type(cls, v): if v is None: return None diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index af7661f0af..447b799f1f 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -101,6 +101,7 @@ class LangfuseTrace(BaseModel): ) 
@field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -171,6 +172,7 @@ class LangfuseSpan(BaseModel): ) @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -196,6 +198,7 @@ class GenerationUsage(BaseModel): totalCost: Optional[float] = None @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -273,6 +276,7 @@ class LangfuseGeneration(BaseModel): model_config = ConfigDict(protected_namespaces=()) @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 8cbf162bf2..05c932fb99 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -51,6 +51,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") @field_validator("inputs", "outputs") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name values = info.data @@ -115,6 +116,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): return v return v @field_validator("start_time", "end_time") + @classmethod def format_time(cls, v, info: ValidationInfo): if not isinstance(v, datetime): diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 
76c808f76e..f13723b51f 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -27,6 +27,7 @@ class ElasticSearchConfig(BaseModel): password: str @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config HOST is required") diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 1d08046641..d6d7136282 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -28,6 +28,7 @@ class MilvusConfig(BaseModel): database: str = "default" @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values.get("uri"): raise ValueError("config MILVUS_URI is required") diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index ecd7e0271c..7c0f620956 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -29,6 +29,7 @@ class OpenSearchConfig(BaseModel): secure: bool = False @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values.get("host"): raise ValueError("config OPENSEARCH_HOST is required") diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index eb2e3e0a8c..06c20ceb5f 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -32,6 +32,7 @@ class OracleVectorConfig(BaseModel): database: str @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config ORACLE_HOST is required") diff --git 
a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index b778582e8a..24b391d63a 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -32,6 +32,7 @@ class PgvectoRSConfig(BaseModel): database: str @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config PGVECTO_RS_HOST is required") diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index b01cd91e07..38dfd24b56 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -25,6 +25,7 @@ class PGVectorConfig(BaseModel): database: str @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config PGVECTOR_HOST is required") diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index d8e4ff628c..0c9d3b343d 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -34,6 +34,7 @@ class RelytConfig(BaseModel): database: str @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config RELYT_HOST is required") diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 0e4b3f67a1..e1ac9d596c 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -29,6 +29,7 @@ class TiDBVectorConfig(BaseModel): program_name: str @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config 
TIDB_VECTOR_HOST is required") diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 750172b015..ca1123c6a0 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -23,6 +23,7 @@ class WeaviateConfig(BaseModel): batch_size: int = 100 @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["endpoint"]: raise ValueError("config WEAVIATE_ENDPOINT is required") diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py index 09f30a59d6..7462824be1 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py @@ -69,7 +69,7 @@ class DallE3Tool(BuiltinTool): self.create_blob_message( blob=b64decode(image.b64_json), meta={"mime_type": "image/png"}, - save_as=self.VARIABLE_KEY.IMAGE.value, + save_as=self.VariableKey.IMAGE.value, ) ) result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}")) diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.py b/api/core/tools/provider/builtin/dalle/tools/dalle2.py index ac7e394911..fbd7397292 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle2.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.py @@ -59,7 +59,7 @@ class DallE2Tool(BuiltinTool): self.create_blob_message( blob=b64decode(image.b64_json), meta={"mime_type": "image/png"}, - save_as=self.VARIABLE_KEY.IMAGE.value, + save_as=self.VariableKey.IMAGE.value, ) ) diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index 2d62cf608f..bcfa2212b6 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -66,7 +66,7 @@ class 
DallE3Tool(BuiltinTool): for image in response.data: mime_type, blob_image = DallE3Tool._decode_image(image.b64_json) blob_message = self.create_blob_message( - blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VARIABLE_KEY.IMAGE.value + blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value ) result.append(blob_message) return result diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py index f76587bea1..0b4f2edff3 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py @@ -34,7 +34,7 @@ class NovitaAiCreateTileTool(BuiltinTool): self.create_blob_message( blob=b64decode(client_result.image_file), meta={"mime_type": f"image/{client_result.image_type}"}, - save_as=self.VARIABLE_KEY.IMAGE.value, + save_as=self.VariableKey.IMAGE.value, ) ) diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py index 9632c163cf..9c61eab9f9 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py @@ -40,7 +40,7 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase): self.create_blob_message( blob=b64decode(image_encoded), meta={"mime_type": f"image/{image.image_type}"}, - save_as=self.VARIABLE_KEY.IMAGE.value, + save_as=self.VariableKey.IMAGE.value, ) ) diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py index 8aefc65131..cac59f76d8 100644 --- a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py +++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py @@ -46,7 +46,7 @@ class QRCodeGeneratorTool(BuiltinTool): image = 
self._generate_qrcode(content, border, error_correction) image_bytes = self._image_to_byte_array(image) return self.create_blob_message( - blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value ) except Exception: logging.exception(f"Failed to generate QR code for content: {content}") diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.py b/api/core/tools/provider/builtin/siliconflow/tools/flux.py index 5fa9926484..1b846624bd 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/flux.py +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.py @@ -32,5 +32,5 @@ class FluxTool(BuiltinTool): res = response.json() result = [self.create_json_message(res)] for image in res.get("images", []): - result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value)) + result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py index e7c3c28d7b..d6a0b03d1b 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py @@ -41,5 +41,5 @@ class StableDiffusionTool(BuiltinTool): res = response.json() result = [self.create_json_message(res)] for image in res.get("images", []): - result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value)) + result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py index a6f5570af2..81d9e8d941 100644 --- 
a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py +++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py @@ -15,16 +15,16 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -class AssembleHeaderException(Exception): +class AssembleHeaderError(Exception): def __init__(self, msg): self.message = msg class Url: - def __init__(this, host, path, schema): - this.host = host - this.path = path - this.schema = schema + def __init__(self, host, path, schema): + self.host = host + self.path = path + self.schema = schema # calculate sha256 and encode to base64 @@ -41,7 +41,7 @@ def parse_url(request_url): schema = request_url[: stidx + 3] edidx = host.index("/") if edidx <= 0: - raise AssembleHeaderException("invalid request url:" + request_url) + raise AssembleHeaderError("invalid request url:" + request_url) path = host[edidx:] host = host[:edidx] u = Url(host, path, schema) @@ -115,7 +115,7 @@ class SparkImgGeneratorTool(BuiltinTool): self.create_blob_message( blob=b64decode(image["base64_image"]), meta={"mime_type": "image/png"}, - save_as=self.VARIABLE_KEY.IMAGE.value, + save_as=self.VariableKey.IMAGE.value, ) ) return result diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.py b/api/core/tools/provider/builtin/stability/tools/text2image.py index c33e3bd78f..12b6cc3352 100644 --- a/api/core/tools/provider/builtin/stability/tools/text2image.py +++ b/api/core/tools/provider/builtin/stability/tools/text2image.py @@ -52,5 +52,5 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): raise Exception(response.text) return self.create_blob_message( - blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value ) diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py 
b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index c31e178067..46137886bd 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -260,7 +260,7 @@ class StableDiffusionTool(BuiltinTool): image = response.json()["images"][0] return self.create_blob_message( - blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value ) except Exception as e: @@ -294,7 +294,7 @@ class StableDiffusionTool(BuiltinTool): image = response.json()["images"][0] return self.create_blob_message( - blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value ) except Exception as e: diff --git a/api/core/tools/provider/builtin/tianditu/tools/staticmap.py b/api/core/tools/provider/builtin/tianditu/tools/staticmap.py index 93803d7937..aeaef08805 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/staticmap.py +++ b/api/core/tools/provider/builtin/tianditu/tools/staticmap.py @@ -45,5 +45,5 @@ class PoiSearchTool(BuiltinTool): ).content return self.create_blob_message( - blob=result, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + blob=result, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value ) diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py index 3ba4996be1..4bd601c0bd 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py @@ -32,7 +32,7 @@ class VectorizerTool(BuiltinTool): if image_id.startswith("__test_"): image_binary = b64decode(VECTORIZER_ICON_PNG) else: - 
image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE) + image_binary = self.get_variable_file(self.VariableKey.IMAGE) if not image_binary: return self.create_text_message("Image not found, please request user to generate image firstly.") diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index ac3dc84db4..d9e9a0faad 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -63,7 +63,7 @@ class Tool(BaseModel, ABC): def __init__(self, **data: Any): super().__init__(**data) - class VARIABLE_KEY(Enum): + class VariableKey(Enum): IMAGE = "image" def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": @@ -142,7 +142,7 @@ class Tool(BaseModel, ABC): if not self.variables: return None - return self.get_variable(self.VARIABLE_KEY.IMAGE) + return self.get_variable(self.VariableKey.IMAGE) def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: """ @@ -189,7 +189,7 @@ class Tool(BaseModel, ABC): result = [] for variable in self.variables.pool: - if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value): + if variable.name.startswith(self.VariableKey.IMAGE.value): result.append(variable) return result diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index c6bd122b37..f4e87a42a7 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -8,7 +8,7 @@ from typing import Any, Optional from flask import Flask, current_app -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import ( NodeRunMetadataKey, @@ -669,7 +669,7 @@ class GraphEngine: parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, ) - except GenerateTaskStoppedException: + except 
GenerateTaskStoppedError: # trigger node run failed event route_node_state.status = RouteNodeState.Status.FAILED route_node_state.failed_reason = "Workflow stopped." diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 4a1787c8c1..a07ba2f740 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence from typing import Any, Optional, Union, cast from configs import dify_config -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider @@ -61,7 +61,7 @@ class CodeNode(BaseNode): # Transform result result = self._transform_result(result, node_data.outputs) - except (CodeExecutionException, ValueError) as e: + except (CodeExecutionError, ValueError) as e: return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 2829144ead..32c99e0d1c 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -2,7 +2,7 @@ import os from collections.abc import Mapping, Sequence from typing import Any, Optional, cast -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage +from 
core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData @@ -45,7 +45,7 @@ class TemplateTransformNode(BaseNode): result = CodeExecutor.execute_workflow_code_template( language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables ) - except CodeExecutionException as e: + except CodeExecutionError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 25021935ee..74a598ada5 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -6,7 +6,7 @@ from typing import Any, Optional, cast from configs import dify_config from core.app.app_config.entities import FileExtraConfig -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.workflow.callbacks.base_workflow_callback import WorkflowCallback @@ -103,7 +103,7 @@ class WorkflowEntry: for callback in callbacks: callback.on_event(event=event) yield event - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except Exception as e: logger.exception("Unknown Error when workflow entry running") diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 72a135e73d..54f6a76e16 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -5,7 +5,7 @@ import 
time import click from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from events.event_handlers.document_index_event import document_index_created from extensions.ext_database import db from models.dataset import Document @@ -43,7 +43,7 @@ def handle(sender, **kwargs): indexing_runner.run(documents) end_at = time.perf_counter() logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/extensions/storage/google_storage.py b/api/extensions/storage/google_storage.py index 9ed1fcf0b4..c42f946fa8 100644 --- a/api/extensions/storage/google_storage.py +++ b/api/extensions/storage/google_storage.py @@ -5,7 +5,7 @@ from collections.abc import Generator from contextlib import closing from flask import Flask -from google.cloud import storage as GoogleCloudStorage +from google.cloud import storage as google_cloud_storage from extensions.storage.base_storage import BaseStorage @@ -23,9 +23,9 @@ class GoogleStorage(BaseStorage): service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") # convert str to object service_account_obj = json.loads(service_account_json) - self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj) + self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj) else: - self.client = GoogleCloudStorage.Client() + self.client = google_cloud_storage.Client() def save(self, filename, data): bucket = self.client.get_bucket(self.bucket_name) diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index 2d306edb40..f89902c5e8 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ 
b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -31,7 +31,7 @@ from Crypto.Util.py3compat import _copy_bytes, bord from Crypto.Util.strxor import strxor -class PKCS1OAEP_Cipher: +class PKCS1OAepCipher: """Cipher object for PKCS#1 v1.5 OAEP. Do not create directly: use :func:`new` instead.""" @@ -237,4 +237,4 @@ def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None): if randfunc is None: randfunc = Random.get_random_bytes - return PKCS1OAEP_Cipher(key, hashAlgo, mgfunc, label, randfunc) + return PKCS1OAepCipher(key, hashAlgo, mgfunc, label, randfunc) diff --git a/api/libs/helper.py b/api/libs/helper.py index af0c2dace1..d664ef1ae7 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -84,7 +84,7 @@ def timestamp_value(timestamp): raise ValueError(error) -class str_len: +class StrLen: """Restrict input to an integer in a range (inclusive)""" def __init__(self, max_length, argument="argument"): @@ -102,7 +102,7 @@ class str_len: return value -class float_range: +class FloatRange: """Restrict input to an float in a range (inclusive)""" def __init__(self, low, high, argument="argument"): @@ -121,7 +121,7 @@ class float_range: return value -class datetime_string: +class DatetimeString: def __init__(self, format, argument="argument"): self.format = format self.argument = argument diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 41d6905899..39c17534e7 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -1,6 +1,6 @@ import json -from core.llm_generator.output_parser.errors import OutputParserException +from core.llm_generator.output_parser.errors import OutputParserError def parse_json_markdown(json_string: str) -> dict: @@ -33,10 +33,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: try: json_obj = parse_json_markdown(text) except json.JSONDecodeError as e: - raise OutputParserException(f"Got invalid JSON object. 
Error: {e}") + raise OutputParserError(f"Got invalid JSON object. Error: {e}") for key in expected_keys: if key not in json_obj: - raise OutputParserException( + raise OutputParserError( f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}" ) return json_obj diff --git a/api/pyproject.toml b/api/pyproject.toml index dc7e271ccf..45bd7e00d8 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -15,8 +15,8 @@ select = [ "C4", # flake8-comprehensions "F", # pyflakes rules "I", # isort rules + "N", # pep8-naming "UP", # pyupgrade rules - "B035", # static-key-dict-comprehension "E101", # mixed-spaces-and-tabs "E111", # indentation-with-invalid-multiple "E112", # no-indented-block @@ -47,9 +47,10 @@ ignore = [ "B006", # mutable-argument-default "B007", # unused-loop-control-variable "B026", # star-arg-unpacking-after-keyword-arg -# "B901", # return-in-generator "B904", # raise-without-from-inside-except "B905", # zip-without-explicit-strict + "N806", # non-lowercase-variable-in-function + "N815", # mixed-case-variable-in-class-scope ] [tool.ruff.lint.per-file-ignores] @@ -65,6 +66,12 @@ ignore = [ "F401", # unused-import "F811", # redefined-while-unused ] +"configs/*" = [ + "N802", # invalid-function-name +] +"libs/gmpy2_pkcs10aep_cipher.py" = [ + "N803", # invalid-argument-name +] [tool.ruff.format] exclude = [ diff --git a/api/services/account_service.py b/api/services/account_service.py index e1b70fc9ed..7fb42f9e81 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -32,7 +32,7 @@ from services.errors.account import ( NoPermissionError, RateLimitExceededError, RoleAlreadyAssignedError, - TenantNotFound, + TenantNotFoundError, ) from tasks.mail_invite_member_task import send_invite_member_mail_task from tasks.mail_reset_password_task import send_reset_password_mail_task @@ -311,13 +311,13 @@ class TenantService: """Get tenant by account and add the role""" tenant = account.current_tenant if not 
tenant: - raise TenantNotFound("Tenant not found.") + raise TenantNotFoundError("Tenant not found.") ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() if ta: tenant.role = ta.role else: - raise TenantNotFound("Tenant not found for the account.") + raise TenantNotFoundError("Tenant not found for the account.") return tenant @staticmethod @@ -614,8 +614,8 @@ class RegisterService: "email": account.email, "workspace_id": tenant.id, } - expiryHours = dify_config.INVITE_EXPIRY_HOURS - redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data)) + expiry_hours = dify_config.INVITE_EXPIRY_HOURS + redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data)) return token @classmethod diff --git a/api/services/errors/account.py b/api/services/errors/account.py index cae31c5066..82dd9f944a 100644 --- a/api/services/errors/account.py +++ b/api/services/errors/account.py @@ -1,7 +1,7 @@ from services.errors.base import BaseServiceError -class AccountNotFound(BaseServiceError): +class AccountNotFoundError(BaseServiceError): pass @@ -25,7 +25,7 @@ class LinkAccountIntegrateError(BaseServiceError): pass -class TenantNotFound(BaseServiceError): +class TenantNotFoundError(BaseServiceError): pass diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 9ea4c99649..6dd755ab03 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ 
-106,7 +106,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logging.info( click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green") ) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index e0da5f9ed0..72c4674e0f 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from configs import dify_config -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from extensions.ext_database import db from models.dataset import Dataset, Document from services.feature_service import FeatureService @@ -72,7 +72,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): indexing_runner.run(documents) end_at = time.perf_counter() logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 6e681bcf4f..cb38bc668d 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, Document, 
DocumentSegment @@ -69,7 +69,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): indexing_runner.run([document]) end_at = time.perf_counter() logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 0a7568c385..f4c3dbd2e2 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from configs import dify_config -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -88,7 +88,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): indexing_runner.run(documents) end_at = time.perf_counter() logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 18bae14ffa..21ea11d4dd 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -5,7 +5,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, 
IndexingRunner from extensions.ext_database import db from models.dataset import Document @@ -39,7 +39,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): logging.info( click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green") ) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py index b37b109eba..83317e59de 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py @@ -70,6 +70,7 @@ class MockTEIClass: }, } + @staticmethod def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]: # Example response: # [ diff --git a/api/tests/integration_tests/tools/__mock/http.py b/api/tests/integration_tests/tools/__mock/http.py index 4dfc530010..d3c1f3101c 100644 --- a/api/tests/integration_tests/tools/__mock/http.py +++ b/api/tests/integration_tests/tools/__mock/http.py @@ -7,6 +7,7 @@ from _pytest.monkeypatch import MonkeyPatch class MockedHttp: + @staticmethod def httpx_request( method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs ) -> httpx.Response: diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 571c1e3d44..53c9b3cae3 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -13,7 +13,7 @@ from xinference_client.types import Embedding class MockTcvectordbClass: - def VectorDBClient( + def mock_vector_db_client( self, url=None, username="", @@ -110,7 +110,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" @pytest.fixture def setup_tcvectordb_mock(request, 
monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.VectorDBClient) + monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client) monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases) monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection) monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections) diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py index cfc47bcad4..f1ab23b002 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/http.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -10,6 +10,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false") == "true" class MockedHttp: + @staticmethod def httpx_request( method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs ) -> httpx.Response: diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py index 44dcf9a10f..487178ff58 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py @@ -1,11 +1,11 @@ import pytest -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor CODE_LANGUAGE = "unsupported_language" def test_unsupported_with_code_template(): - with pytest.raises(CodeExecutionException) as e: + with pytest.raises(CodeExecutionError) as e: CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}"