chore: apply pep8-naming rules for naming convention (#8261)

This commit is contained in:
Bowen Liang 2024-09-11 16:40:52 +08:00 committed by GitHub
parent 53f37a6704
commit 292220c596
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
95 changed files with 287 additions and 258 deletions

View File

@ -20,7 +20,7 @@ from fields.conversation_fields import (
conversation_pagination_fields,
conversation_with_summary_pagination_fields,
)
from libs.helper import datetime_string
from libs.helper import DatetimeString
from libs.login import login_required
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
@ -36,8 +36,8 @@ class CompletionConversationApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
@ -143,8 +143,8 @@ class ChatConversationApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)

View File

@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.helper import datetime_string
from libs.helper import DatetimeString
from libs.login import login_required
from models.model import AppMode
@ -25,8 +25,8 @@ class DailyMessageStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@ -79,8 +79,8 @@ class DailyConversationStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@ -133,8 +133,8 @@ class DailyTerminalsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@ -187,8 +187,8 @@ class DailyTokenCostStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@ -245,8 +245,8 @@ class AverageSessionInteractionStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
@ -307,8 +307,8 @@ class UserSatisfactionRateStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@ -369,8 +369,8 @@ class AverageResponseTimeStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@ -425,8 +425,8 @@ class TokensPerSecondStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,

View File

@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.helper import datetime_string
from libs.helper import DatetimeString
from libs.login import login_required
from models.model import AppMode
from models.workflow import WorkflowRunTriggeredFrom
@ -26,8 +26,8 @@ class WorkflowDailyRunsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@ -86,8 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@ -146,8 +146,8 @@ class WorkflowDailyTokenCostStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@ -213,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """

View File

@ -8,7 +8,7 @@ from constants.languages import supported_language
from controllers.console import api
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.helper import email, str_len, timezone
from libs.helper import StrLen, email, timezone
from libs.password import hash_password, valid_password
from models.account import AccountStatus
from services.account_service import RegisterService
@ -37,7 +37,7 @@ class ActivateApi(Resource):
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json"

View File

@ -4,7 +4,7 @@ from flask import session
from flask_restful import Resource, reqparse
from configs import dify_config
from libs.helper import str_len
from libs.helper import StrLen
from models.model import DifySetup
from services.account_service import TenantService
@ -28,7 +28,7 @@ class InitValidateAPI(Resource):
raise AlreadySetupError()
parser = reqparse.RequestParser()
parser.add_argument("password", type=str_len(30), required=True, location="json")
parser.add_argument("password", type=StrLen(30), required=True, location="json")
input_password = parser.parse_args()["password"]
if input_password != os.environ.get("INIT_PASSWORD"):

View File

@ -4,7 +4,7 @@ from flask import request
from flask_restful import Resource, reqparse
from configs import dify_config
from libs.helper import email, get_remote_ip, str_len
from libs.helper import StrLen, email, get_remote_ip
from libs.password import valid_password
from models.model import DifySetup
from services.account_service import RegisterService, TenantService
@ -40,7 +40,7 @@ class SetupApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("name", type=str_len(30), required=True, location="json")
parser.add_argument("name", type=StrLen(30), required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
args = parser.parse_args()

View File

@ -15,7 +15,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
@ -293,7 +293,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
)
runner.run()
except GenerateTaskStoppedException:
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
@ -349,7 +349,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
return generate_task_pipeline.process()
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedException()
raise GenerateTaskStoppedError()
else:
logger.exception(e)
raise e

View File

@ -21,7 +21,7 @@ class AudioTrunk:
self.status = status
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
@ -81,7 +81,7 @@ class AppGeneratorTTSPublisher:
if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit(
_invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice
_invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
)
future_queue.put(futures_result)
break
@ -97,7 +97,7 @@ class AppGeneratorTTSPublisher:
self.MAX_SENTENCE += 1
text_content = "".join(sentence_arr)
futures_result = self.executor.submit(
_invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
)
future_queue.put(futures_result)
if text_tmp:
@ -110,7 +110,7 @@ class AppGeneratorTTSPublisher:
break
future_queue.put(None)
def checkAndGetAudio(self) -> AudioTrunk | None:
def check_and_get_audio(self) -> AudioTrunk | None:
try:
if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor:

View File

@ -19,7 +19,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
QueueTextChunkEvent,
)
from core.moderation.base import ModerationException
from core.moderation.base import ModerationError
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
@ -217,7 +217,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
query=query,
message_id=message_id,
)
except ModerationException as e:
except ModerationError as e:
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
return True

View File

@ -179,10 +179,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream_response=stream_response,
)
def _listenAudioMsg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
@ -204,7 +204,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@ -217,7 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
try:
if not tts_publisher:
break
audio_trunk = tts_publisher.checkAndGetAudio()
audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)

View File

@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
@ -205,7 +205,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
)
except GenerateTaskStoppedException:
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(

View File

@ -15,7 +15,7 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException
from core.moderation.base import ModerationError
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought
@ -103,7 +103,7 @@ class AgentChatAppRunner(AppRunner):
query=query,
message_id=message.id,
)
except ModerationException as e:
except ModerationError as e:
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,

View File

@ -171,5 +171,5 @@ class AppQueueManager:
)
class GenerateTaskStoppedException(Exception):
class GenerateTaskStoppedError(Exception):
pass

View File

@ -10,7 +10,7 @@ from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
@ -205,7 +205,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
)
except GenerateTaskStoppedException:
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(

View File

@ -11,7 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db
from models.model import App, Conversation, Message
@ -98,7 +98,7 @@ class ChatAppRunner(AppRunner):
query=query,
message_id=message.id,
)
except ModerationException as e:
except ModerationError as e:
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,

View File

@ -10,7 +10,7 @@ from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.app.apps.completion.app_runner import CompletionAppRunner
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
@ -185,7 +185,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
queue_manager=queue_manager,
message=message,
)
except GenerateTaskStoppedException:
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(

View File

@ -9,7 +9,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db
from models.model import App, Message
@ -79,7 +79,7 @@ class CompletionAppRunner(AppRunner):
query=query,
message_id=message.id,
)
except ModerationException as e:
except ModerationError as e:
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,

View File

@ -8,7 +8,7 @@ from sqlalchemy import and_
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
@ -77,7 +77,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return generate_task_pipeline.process()
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedException()
raise GenerateTaskStoppedError()
else:
logger.exception(e)
raise e

View File

@ -1,4 +1,4 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
@ -53,4 +53,4 @@ class MessageBasedAppQueueManager(AppQueueManager):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedException()
raise GenerateTaskStoppedError()

View File

@ -12,7 +12,7 @@ from pydantic import ValidationError
import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner
@ -253,7 +253,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
)
runner.run()
except GenerateTaskStoppedException:
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
@ -302,7 +302,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
return generate_task_pipeline.process()
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedException()
raise GenerateTaskStoppedError()
else:
logger.exception(e)
raise e

View File

@ -1,4 +1,4 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
@ -39,4 +39,4 @@ class WorkflowAppQueueManager(AppQueueManager):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedException()
raise GenerateTaskStoppedError()

View File

@ -162,10 +162,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
def _listenAudioMsg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
@ -187,7 +187,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@ -199,7 +199,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
try:
if not tts_publisher:
break
audio_trunk = tts_publisher.checkAndGetAudio()
audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)

View File

@ -15,6 +15,7 @@ class Segment(BaseModel):
value: Any
@field_validator("value_type")
@classmethod
def validate_value_type(cls, value):
"""
This validator checks if the provided value is equal to the default value of the 'value_type' field.

View File

@ -201,10 +201,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
stream_response=stream_response,
)
def _listenAudioMsg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher, task_id: str):
if publisher is None:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
@ -225,7 +225,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id)
audio_response = self._listen_audio_msg(publisher, task_id)
if audio_response:
yield audio_response
else:
@ -237,7 +237,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
if publisher is None:
break
audio = publisher.checkAndGetAudio()
audio = publisher.check_and_get_audio()
if audio is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)

View File

@ -16,7 +16,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
logger = logging.getLogger(__name__)
class CodeExecutionException(Exception):
class CodeExecutionError(Exception):
pass
@ -86,15 +86,15 @@ class CodeExecutor:
),
)
if response.status_code == 503:
raise CodeExecutionException("Code execution service is unavailable")
raise CodeExecutionError("Code execution service is unavailable")
elif response.status_code != 200:
raise Exception(
f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
)
except CodeExecutionException as e:
except CodeExecutionError as e:
raise e
except Exception as e:
raise CodeExecutionException(
raise CodeExecutionError(
"Failed to execute code, which is likely a network issue,"
" please check if the sandbox service is running."
f" ( Error: {str(e)} )"
@ -103,15 +103,15 @@ class CodeExecutor:
try:
response = response.json()
except:
raise CodeExecutionException("Failed to parse response")
raise CodeExecutionError("Failed to parse response")
if (code := response.get("code")) != 0:
raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}")
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}")
response = CodeExecutionResponse(**response)
if response.data.error:
raise CodeExecutionException(response.data.error)
raise CodeExecutionError(response.data.error)
return response.data.stdout or ""
@ -126,13 +126,13 @@ class CodeExecutor:
"""
template_transformer = cls.code_template_transformers.get(language)
if not template_transformer:
raise CodeExecutionException(f"Unsupported language {language}")
raise CodeExecutionError(f"Unsupported language {language}")
runner, preload = template_transformer.transform_caller(code, inputs)
try:
response = cls.execute_code(language, preload, runner)
except CodeExecutionException as e:
except CodeExecutionError as e:
raise e
return template_transformer.transform_response(response)

View File

@ -78,8 +78,8 @@ class IndexingRunner:
dataset_document=dataset_document,
documents=documents,
)
except DocumentIsPausedException:
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
except DocumentIsPausedError:
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
@ -134,8 +134,8 @@ class IndexingRunner:
self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
except DocumentIsPausedError:
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
@ -192,8 +192,8 @@ class IndexingRunner:
self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
except DocumentIsPausedError:
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
@ -756,7 +756,7 @@ class IndexingRunner:
indexing_cache_key = "document_{}_is_paused".format(document_id)
result = redis_client.get(indexing_cache_key)
if result:
raise DocumentIsPausedException()
raise DocumentIsPausedError()
@staticmethod
def _update_document_index_status(
@ -767,10 +767,10 @@ class IndexingRunner:
"""
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedException()
raise DocumentIsPausedError()
document = DatasetDocument.query.filter_by(id=document_id).first()
if not document:
raise DocumentIsDeletedPausedException()
raise DocumentIsDeletedPausedError()
update_params = {DatasetDocument.indexing_status: after_indexing_status}
@ -875,9 +875,9 @@ class IndexingRunner:
pass
class DocumentIsPausedException(Exception):
class DocumentIsPausedError(Exception):
pass
class DocumentIsDeletedPausedException(Exception):
class DocumentIsDeletedPausedError(Exception):
pass

View File

@ -1,2 +1,2 @@
class OutputParserException(Exception):
class OutputParserError(Exception):
pass

View File

@ -1,6 +1,6 @@
from typing import Any
from core.llm_generator.output_parser.errors import OutputParserException
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.prompts import (
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
@ -29,4 +29,4 @@ class RuleConfigGeneratorOutputParser:
raise ValueError("Expected 'opening_statement' to be a str.")
return parsed
except Exception as e:
raise OutputParserException(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}")
raise OutputParserError(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}")

View File

@ -7,7 +7,7 @@ from requests import post
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError,
InsufficientAccountBalance,
InsufficientAccountBalanceError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
@ -124,7 +124,7 @@ class BaichuanModel:
if err == "invalid_api_key":
raise InvalidAPIKeyError(msg)
elif err == "insufficient_quota":
raise InsufficientAccountBalance(msg)
raise InsufficientAccountBalanceError(msg)
elif err == "invalid_authentication":
raise InvalidAuthenticationError(msg)
elif err == "invalid_request_error":

View File

@ -10,7 +10,7 @@ class RateLimitReachedError(Exception):
pass
class InsufficientAccountBalance(Exception):
class InsufficientAccountBalanceError(Exception):
pass

View File

@ -29,7 +29,7 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import B
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError,
InsufficientAccountBalance,
InsufficientAccountBalanceError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
@ -289,7 +289,7 @@ class BaichuanLanguageModel(LargeLanguageModel):
InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InsufficientAccountBalanceError,
InvalidAPIKeyError,
],
InvokeBadRequestError: [BadRequestError, KeyError],

View File

@ -19,7 +19,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError,
InsufficientAccountBalance,
InsufficientAccountBalanceError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
@ -109,7 +109,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
if err == "invalid_api_key":
raise InvalidAPIKeyError(msg)
elif err == "insufficient_quota":
raise InsufficientAccountBalance(msg)
raise InsufficientAccountBalanceError(msg)
elif err == "invalid_authentication":
raise InvalidAuthenticationError(msg)
elif err and "rate" in err:
@ -166,7 +166,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InsufficientAccountBalanceError,
InvalidAPIKeyError,
],
InvokeBadRequestError: [BadRequestError, KeyError],

View File

@ -10,7 +10,7 @@ from core.model_runtime.errors.invoke import (
)
class _CommonOAI_API_Compat:
class _CommonOaiApiCompat:
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""

View File

@ -35,13 +35,13 @@ from core.model_runtime.entities.model_entities import (
from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
from core.model_runtime.utils import helper
logger = logging.getLogger(__name__)
class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
"""
Model class for OpenAI large language model.
"""

View File

@ -6,10 +6,10 @@ import requests
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel):
class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel):
"""
Model class for OpenAI Compatible Speech to text model.
"""

View File

@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import (
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
"""
Model class for an OpenAI API-compatible text embedding model.
"""

View File

@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import (
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
"""
Model class for an OpenAI API-compatible text embedding model.
"""

View File

@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasError, MaasService
class MaaSClient(MaasService):
@ -106,7 +106,7 @@ class MaaSClient(MaasService):
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
try:
resp = fn()
except MaasException as e:
except MaasError as e:
raise wrap_error(e)
return resp

View File

@ -1,144 +1,144 @@
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasError
class ClientSDKRequestError(MaasException):
class ClientSDKRequestError(MaasError):
pass
class SignatureDoesNotMatch(MaasException):
class SignatureDoesNotMatchError(MaasError):
pass
class RequestTimeout(MaasException):
class RequestTimeoutError(MaasError):
pass
class ServiceConnectionTimeout(MaasException):
class ServiceConnectionTimeoutError(MaasError):
pass
class MissingAuthenticationHeader(MaasException):
class MissingAuthenticationHeaderError(MaasError):
pass
class AuthenticationHeaderIsInvalid(MaasException):
class AuthenticationHeaderIsInvalidError(MaasError):
pass
class InternalServiceError(MaasException):
class InternalServiceError(MaasError):
pass
class MissingParameter(MaasException):
class MissingParameterError(MaasError):
pass
class InvalidParameter(MaasException):
class InvalidParameterError(MaasError):
pass
class AuthenticationExpire(MaasException):
class AuthenticationExpireError(MaasError):
pass
class EndpointIsInvalid(MaasException):
class EndpointIsInvalidError(MaasError):
pass
class EndpointIsNotEnable(MaasException):
class EndpointIsNotEnableError(MaasError):
pass
class ModelNotSupportStreamMode(MaasException):
class ModelNotSupportStreamModeError(MaasError):
pass
class ReqTextExistRisk(MaasException):
class ReqTextExistRiskError(MaasError):
pass
class RespTextExistRisk(MaasException):
class RespTextExistRiskError(MaasError):
pass
class EndpointRateLimitExceeded(MaasException):
class EndpointRateLimitExceededError(MaasError):
pass
class ServiceConnectionRefused(MaasException):
class ServiceConnectionRefusedError(MaasError):
pass
class ServiceConnectionClosed(MaasException):
class ServiceConnectionClosedError(MaasError):
pass
class UnauthorizedUserForEndpoint(MaasException):
class UnauthorizedUserForEndpointError(MaasError):
pass
class InvalidEndpointWithNoURL(MaasException):
class InvalidEndpointWithNoURLError(MaasError):
pass
class EndpointAccountRpmRateLimitExceeded(MaasException):
class EndpointAccountRpmRateLimitExceededError(MaasError):
pass
class EndpointAccountTpmRateLimitExceeded(MaasException):
class EndpointAccountTpmRateLimitExceededError(MaasError):
pass
class ServiceResourceWaitQueueFull(MaasException):
class ServiceResourceWaitQueueFullError(MaasError):
pass
class EndpointIsPending(MaasException):
class EndpointIsPendingError(MaasError):
pass
class ServiceNotOpen(MaasException):
class ServiceNotOpenError(MaasError):
pass
AuthErrors = {
"SignatureDoesNotMatch": SignatureDoesNotMatch,
"MissingAuthenticationHeader": MissingAuthenticationHeader,
"AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalid,
"AuthenticationExpire": AuthenticationExpire,
"UnauthorizedUserForEndpoint": UnauthorizedUserForEndpoint,
"SignatureDoesNotMatch": SignatureDoesNotMatchError,
"MissingAuthenticationHeader": MissingAuthenticationHeaderError,
"AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalidError,
"AuthenticationExpire": AuthenticationExpireError,
"UnauthorizedUserForEndpoint": UnauthorizedUserForEndpointError,
}
BadRequestErrors = {
"MissingParameter": MissingParameter,
"InvalidParameter": InvalidParameter,
"EndpointIsInvalid": EndpointIsInvalid,
"EndpointIsNotEnable": EndpointIsNotEnable,
"ModelNotSupportStreamMode": ModelNotSupportStreamMode,
"ReqTextExistRisk": ReqTextExistRisk,
"RespTextExistRisk": RespTextExistRisk,
"InvalidEndpointWithNoURL": InvalidEndpointWithNoURL,
"ServiceNotOpen": ServiceNotOpen,
"MissingParameter": MissingParameterError,
"InvalidParameter": InvalidParameterError,
"EndpointIsInvalid": EndpointIsInvalidError,
"EndpointIsNotEnable": EndpointIsNotEnableError,
"ModelNotSupportStreamMode": ModelNotSupportStreamModeError,
"ReqTextExistRisk": ReqTextExistRiskError,
"RespTextExistRisk": RespTextExistRiskError,
"InvalidEndpointWithNoURL": InvalidEndpointWithNoURLError,
"ServiceNotOpen": ServiceNotOpenError,
}
RateLimitErrors = {
"EndpointRateLimitExceeded": EndpointRateLimitExceeded,
"EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceeded,
"EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceeded,
"EndpointRateLimitExceeded": EndpointRateLimitExceededError,
"EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceededError,
"EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceededError,
}
ServerUnavailableErrors = {
"InternalServiceError": InternalServiceError,
"EndpointIsPending": EndpointIsPending,
"ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFull,
"EndpointIsPending": EndpointIsPendingError,
"ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFullError,
}
ConnectionErrors = {
"ClientSDKRequestError": ClientSDKRequestError,
"RequestTimeout": RequestTimeout,
"ServiceConnectionTimeout": ServiceConnectionTimeout,
"ServiceConnectionRefused": ServiceConnectionRefused,
"ServiceConnectionClosed": ServiceConnectionClosed,
"RequestTimeout": RequestTimeoutError,
"ServiceConnectionTimeout": ServiceConnectionTimeoutError,
"ServiceConnectionRefused": ServiceConnectionRefusedError,
"ServiceConnectionClosed": ServiceConnectionClosedError,
}
ErrorCodeMap = {
@ -150,7 +150,7 @@ ErrorCodeMap = {
}
def wrap_error(e: MaasException) -> Exception:
def wrap_error(e: MaasError) -> Exception:
if ErrorCodeMap.get(e.code):
return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id)
return e

View File

@ -1,4 +1,4 @@
from .common import ChatRole
from .maas import MaasException, MaasService
from .maas import MaasError, MaasService
__all__ = ["MaasService", "ChatRole", "MaasException"]
__all__ = ["MaasService", "ChatRole", "MaasError"]

View File

@ -63,7 +63,7 @@ class MaasService(Service):
raise
if res.error is not None and res.error.code_n != 0:
raise MaasException(
raise MaasError(
res.error.code_n,
res.error.code,
res.error.message,
@ -72,7 +72,7 @@ class MaasService(Service):
yield res
return iter_fn()
except MaasException:
except MaasError:
raise
except Exception as e:
raise new_client_sdk_request_error(str(e))
@ -94,7 +94,7 @@ class MaasService(Service):
resp["req_id"] = req_id
return resp
except MaasException as e:
except MaasError as e:
raise e
except Exception as e:
raise new_client_sdk_request_error(str(e), req_id)
@ -147,14 +147,14 @@ class MaasService(Service):
raise new_client_sdk_request_error(raw, req_id)
if resp.error:
raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, req_id)
raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, req_id)
else:
raise new_client_sdk_request_error(resp, req_id)
return res
class MaasException(Exception):
class MaasError(Exception):
def __init__(self, code_n, code, message, req_id):
self.code_n = code_n
self.code = code
@ -172,7 +172,7 @@ class MaasException(Exception):
def new_client_sdk_request_error(raw, req_id=""):
return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id)
return MaasError(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id)
class BinaryResponseContent:
@ -192,7 +192,7 @@ class BinaryResponseContent:
if len(error_bytes) > 0:
resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id)
raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, self.request_id)
raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, self.request_id)
def iter_bytes(self) -> Iterator[bytes]:
yield from self.response

View File

@ -35,7 +35,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
AuthErrors,
BadRequestErrors,
ConnectionErrors,
MaasException,
MaasError,
RateLimitErrors,
ServerUnavailableErrors,
)
@ -85,7 +85,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
},
[UserPromptMessage(content="ping\nAnswer: ")],
)
except MaasException as e:
except MaasError as e:
raise CredentialsValidateFailedError(e.message)
@staticmethod

View File

@ -28,7 +28,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
AuthErrors,
BadRequestErrors,
ConnectionErrors,
MaasException,
MaasError,
RateLimitErrors,
ServerUnavailableErrors,
)
@ -111,7 +111,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
def _validate_credentials_v2(self, model: str, credentials: dict) -> None:
try:
self._invoke(model=model, credentials=credentials, texts=["ping"])
except MaasException as e:
except MaasError as e:
raise CredentialsValidateFailedError(e.message)
def _validate_credentials_v3(self, model: str, credentials: dict) -> None:

View File

@ -23,7 +23,7 @@ def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InsufficientAccountBalanceError,
InvalidAPIKeyError,
],
InvokeBadRequestError: [BadRequestError, KeyError],
@ -42,7 +42,7 @@ class RateLimitReachedError(Exception):
pass
class InsufficientAccountBalance(Exception):
class InsufficientAccountBalanceError(Exception):
pass

View File

@ -76,7 +76,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError
@classmethod
def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None:
def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None:
# inputs_config
inputs_config = config.get("inputs_config")
if not isinstance(inputs_config, dict):
@ -111,5 +111,5 @@ class Moderation(Extensible, ABC):
raise ValueError("outputs_config.preset_response must be less than 100 characters")
class ModerationException(Exception):
class ModerationError(Exception):
pass

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional
from core.app.app_config.entities import AppConfig
from core.moderation.base import ModerationAction, ModerationException
from core.moderation.base import ModerationAction, ModerationError
from core.moderation.factory import ModerationFactory
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
@ -61,7 +61,7 @@ class InputModeration:
return False, inputs, query
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
raise ModerationException(moderation_result.preset_response)
raise ModerationError(moderation_result.preset_response)
elif moderation_result.action == ModerationAction.OVERRIDDEN:
inputs = moderation_result.inputs
query = moderation_result.query

View File

@ -26,6 +26,7 @@ class LangfuseConfig(BaseTracingConfig):
host: str = "https://api.langfuse.com"
@field_validator("host")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://api.langfuse.com"
@ -45,6 +46,7 @@ class LangSmithConfig(BaseTracingConfig):
endpoint: str = "https://api.smith.langchain.com"
@field_validator("endpoint")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://api.smith.langchain.com"

View File

@ -15,6 +15,7 @@ class BaseTraceInfo(BaseModel):
metadata: dict[str, Any]
@field_validator("inputs", "outputs")
@classmethod
def ensure_type(cls, v):
if v is None:
return None

View File

@ -101,6 +101,7 @@ class LangfuseTrace(BaseModel):
)
@field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
return validate_input_output(v, field_name)
@ -171,6 +172,7 @@ class LangfuseSpan(BaseModel):
)
@field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
return validate_input_output(v, field_name)
@ -196,6 +198,7 @@ class GenerationUsage(BaseModel):
totalCost: Optional[float] = None
@field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
return validate_input_output(v, field_name)
@ -273,6 +276,7 @@ class LangfuseGeneration(BaseModel):
model_config = ConfigDict(protected_namespaces=())
@field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
return validate_input_output(v, field_name)

View File

@ -51,6 +51,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
@field_validator("inputs", "outputs")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
values = info.data
@ -115,6 +116,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
return v
return v
@classmethod
@field_validator("start_time", "end_time")
def format_time(cls, v, info: ValidationInfo):
if not isinstance(v, datetime):

View File

@ -27,6 +27,7 @@ class ElasticSearchConfig(BaseModel):
password: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config HOST is required")

View File

@ -28,6 +28,7 @@ class MilvusConfig(BaseModel):
database: str = "default"
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")

View File

@ -29,6 +29,7 @@ class OpenSearchConfig(BaseModel):
secure: bool = False
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values.get("host"):
raise ValueError("config OPENSEARCH_HOST is required")

View File

@ -32,6 +32,7 @@ class OracleVectorConfig(BaseModel):
database: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config ORACLE_HOST is required")

View File

@ -32,6 +32,7 @@ class PgvectoRSConfig(BaseModel):
database: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config PGVECTO_RS_HOST is required")

View File

@ -25,6 +25,7 @@ class PGVectorConfig(BaseModel):
database: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config PGVECTOR_HOST is required")

View File

@ -34,6 +34,7 @@ class RelytConfig(BaseModel):
database: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config RELYT_HOST is required")

View File

@ -29,6 +29,7 @@ class TiDBVectorConfig(BaseModel):
program_name: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config TIDB_VECTOR_HOST is required")

View File

@ -23,6 +23,7 @@ class WeaviateConfig(BaseModel):
batch_size: int = 100
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["endpoint"]:
raise ValueError("config WEAVIATE_ENDPOINT is required")

View File

@ -69,7 +69,7 @@ class DallE3Tool(BuiltinTool):
self.create_blob_message(
blob=b64decode(image.b64_json),
meta={"mime_type": "image/png"},
save_as=self.VARIABLE_KEY.IMAGE.value,
save_as=self.VariableKey.IMAGE.value,
)
)
result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}"))

View File

@ -59,7 +59,7 @@ class DallE2Tool(BuiltinTool):
self.create_blob_message(
blob=b64decode(image.b64_json),
meta={"mime_type": "image/png"},
save_as=self.VARIABLE_KEY.IMAGE.value,
save_as=self.VariableKey.IMAGE.value,
)
)

View File

@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool):
for image in response.data:
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
blob_message = self.create_blob_message(
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VARIABLE_KEY.IMAGE.value
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
)
result.append(blob_message)
return result

View File

@ -34,7 +34,7 @@ class NovitaAiCreateTileTool(BuiltinTool):
self.create_blob_message(
blob=b64decode(client_result.image_file),
meta={"mime_type": f"image/{client_result.image_type}"},
save_as=self.VARIABLE_KEY.IMAGE.value,
save_as=self.VariableKey.IMAGE.value,
)
)

View File

@ -40,7 +40,7 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
self.create_blob_message(
blob=b64decode(image_encoded),
meta={"mime_type": f"image/{image.image_type}"},
save_as=self.VARIABLE_KEY.IMAGE.value,
save_as=self.VariableKey.IMAGE.value,
)
)

View File

@ -46,7 +46,7 @@ class QRCodeGeneratorTool(BuiltinTool):
image = self._generate_qrcode(content, border, error_correction)
image_bytes = self._image_to_byte_array(image)
return self.create_blob_message(
blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
)
except Exception:
logging.exception(f"Failed to generate QR code for content: {content}")

View File

@ -32,5 +32,5 @@ class FluxTool(BuiltinTool):
res = response.json()
result = [self.create_json_message(res)]
for image in res.get("images", []):
result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value))
result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value))
return result

View File

@ -41,5 +41,5 @@ class StableDiffusionTool(BuiltinTool):
res = response.json()
result = [self.create_json_message(res)]
for image in res.get("images", []):
result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value))
result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value))
return result

View File

@ -15,16 +15,16 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class AssembleHeaderException(Exception):
class AssembleHeaderError(Exception):
def __init__(self, msg):
self.message = msg
class Url:
def __init__(this, host, path, schema):
this.host = host
this.path = path
this.schema = schema
def __init__(self, host, path, schema):
self.host = host
self.path = path
self.schema = schema
# calculate sha256 and encode to base64
@ -41,7 +41,7 @@ def parse_url(request_url):
schema = request_url[: stidx + 3]
edidx = host.index("/")
if edidx <= 0:
raise AssembleHeaderException("invalid request url:" + request_url)
raise AssembleHeaderError("invalid request url:" + request_url)
path = host[edidx:]
host = host[:edidx]
u = Url(host, path, schema)
@ -115,7 +115,7 @@ class SparkImgGeneratorTool(BuiltinTool):
self.create_blob_message(
blob=b64decode(image["base64_image"]),
meta={"mime_type": "image/png"},
save_as=self.VARIABLE_KEY.IMAGE.value,
save_as=self.VariableKey.IMAGE.value,
)
)
return result

View File

@ -52,5 +52,5 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
raise Exception(response.text)
return self.create_blob_message(
blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
)

View File

@ -260,7 +260,7 @@ class StableDiffusionTool(BuiltinTool):
image = response.json()["images"][0]
return self.create_blob_message(
blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
)
except Exception as e:
@ -294,7 +294,7 @@ class StableDiffusionTool(BuiltinTool):
image = response.json()["images"][0]
return self.create_blob_message(
blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
)
except Exception as e:

View File

@ -45,5 +45,5 @@ class PoiSearchTool(BuiltinTool):
).content
return self.create_blob_message(
blob=result, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
blob=result, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
)

View File

@ -32,7 +32,7 @@ class VectorizerTool(BuiltinTool):
if image_id.startswith("__test_"):
image_binary = b64decode(VECTORIZER_ICON_PNG)
else:
image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)
image_binary = self.get_variable_file(self.VariableKey.IMAGE)
if not image_binary:
return self.create_text_message("Image not found, please request user to generate image firstly.")

View File

@ -63,7 +63,7 @@ class Tool(BaseModel, ABC):
def __init__(self, **data: Any):
super().__init__(**data)
class VARIABLE_KEY(Enum):
class VariableKey(Enum):
IMAGE = "image"
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
@ -142,7 +142,7 @@ class Tool(BaseModel, ABC):
if not self.variables:
return None
return self.get_variable(self.VARIABLE_KEY.IMAGE)
return self.get_variable(self.VariableKey.IMAGE)
def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
"""
@ -189,7 +189,7 @@ class Tool(BaseModel, ABC):
result = []
for variable in self.variables.pool:
if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value):
if variable.name.startswith(self.VariableKey.IMAGE.value):
result.append(variable)
return result

View File

@ -8,7 +8,7 @@ from typing import Any, Optional
from flask import Flask, current_app
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
@ -669,7 +669,7 @@ class GraphEngine:
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
except GenerateTaskStoppedException:
except GenerateTaskStoppedError:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."

View File

@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union, cast
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
@ -61,7 +61,7 @@ class CodeNode(BaseNode):
# Transform result
result = self._transform_result(result, node_data.outputs)
except (CodeExecutionException, ValueError) as e:
except (CodeExecutionError, ValueError) as e:
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)

View File

@ -2,7 +2,7 @@ import os
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
@ -45,7 +45,7 @@ class TemplateTransformNode(BaseNode):
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables
)
except CodeExecutionException as e:
except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:

View File

@ -6,7 +6,7 @@ from typing import Any, Optional, cast
from configs import dify_config
from core.app.app_config.entities import FileExtraConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
@ -103,7 +103,7 @@ class WorkflowEntry:
for callback in callbacks:
callback.on_event(event=event)
yield event
except GenerateTaskStoppedException:
except GenerateTaskStoppedError:
pass
except Exception as e:
logger.exception("Unknown Error when workflow entry running")

View File

@ -5,7 +5,7 @@ import time
import click
from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedException, IndexingRunner
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from events.event_handlers.document_index_event import document_index_created
from extensions.ext_database import db
from models.dataset import Document
@ -43,7 +43,7 @@ def handle(sender, **kwargs):
indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
except DocumentIsPausedException as ex:
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass

View File

@ -5,7 +5,7 @@ from collections.abc import Generator
from contextlib import closing
from flask import Flask
from google.cloud import storage as GoogleCloudStorage
from google.cloud import storage as google_cloud_storage
from extensions.storage.base_storage import BaseStorage
@ -23,9 +23,9 @@ class GoogleStorage(BaseStorage):
service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")
# convert str to object
service_account_obj = json.loads(service_account_json)
self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj)
self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj)
else:
self.client = GoogleCloudStorage.Client()
self.client = google_cloud_storage.Client()
def save(self, filename, data):
bucket = self.client.get_bucket(self.bucket_name)

View File

@ -31,7 +31,7 @@ from Crypto.Util.py3compat import _copy_bytes, bord
from Crypto.Util.strxor import strxor
class PKCS1OAEP_Cipher:
class PKCS1OAepCipher:
"""Cipher object for PKCS#1 v1.5 OAEP.
Do not create directly: use :func:`new` instead."""
@ -237,4 +237,4 @@ def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None):
if randfunc is None:
randfunc = Random.get_random_bytes
return PKCS1OAEP_Cipher(key, hashAlgo, mgfunc, label, randfunc)
return PKCS1OAepCipher(key, hashAlgo, mgfunc, label, randfunc)

View File

@ -84,7 +84,7 @@ def timestamp_value(timestamp):
raise ValueError(error)
class str_len:
class StrLen:
"""Restrict input to an integer in a range (inclusive)"""
def __init__(self, max_length, argument="argument"):
@ -102,7 +102,7 @@ class str_len:
return value
class float_range:
class FloatRange:
"""Restrict input to an float in a range (inclusive)"""
def __init__(self, low, high, argument="argument"):
@ -121,7 +121,7 @@ class float_range:
return value
class datetime_string:
class DatetimeString:
def __init__(self, format, argument="argument"):
self.format = format
self.argument = argument

View File

@ -1,6 +1,6 @@
import json
from core.llm_generator.output_parser.errors import OutputParserException
from core.llm_generator.output_parser.errors import OutputParserError
def parse_json_markdown(json_string: str) -> dict:
@ -33,10 +33,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
try:
json_obj = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
raise OutputParserError(f"Got invalid JSON object. Error: {e}")
for key in expected_keys:
if key not in json_obj:
raise OutputParserException(
raise OutputParserError(
f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}"
)
return json_obj

View File

@ -15,8 +15,8 @@ select = [
"C4", # flake8-comprehensions
"F", # pyflakes rules
"I", # isort rules
"N", # pep8-naming
"UP", # pyupgrade rules
"B035", # static-key-dict-comprehension
"E101", # mixed-spaces-and-tabs
"E111", # indentation-with-invalid-multiple
"E112", # no-indented-block
@ -47,9 +47,10 @@ ignore = [
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
# "B901", # return-in-generator
"B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict
"N806", # non-lowercase-variable-in-function
"N815", # mixed-case-variable-in-class-scope
]
[tool.ruff.lint.per-file-ignores]
@ -65,6 +66,12 @@ ignore = [
"F401", # unused-import
"F811", # redefined-while-unused
]
"configs/*" = [
"N802", # invalid-function-name
]
"libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name
]
[tool.ruff.format]
exclude = [

View File

@ -32,7 +32,7 @@ from services.errors.account import (
NoPermissionError,
RateLimitExceededError,
RoleAlreadyAssignedError,
TenantNotFound,
TenantNotFoundError,
)
from tasks.mail_invite_member_task import send_invite_member_mail_task
from tasks.mail_reset_password_task import send_reset_password_mail_task
@ -311,13 +311,13 @@ class TenantService:
"""Get tenant by account and add the role"""
tenant = account.current_tenant
if not tenant:
raise TenantNotFound("Tenant not found.")
raise TenantNotFoundError("Tenant not found.")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
if ta:
tenant.role = ta.role
else:
raise TenantNotFound("Tenant not found for the account.")
raise TenantNotFoundError("Tenant not found for the account.")
return tenant
@staticmethod
@ -614,8 +614,8 @@ class RegisterService:
"email": account.email,
"workspace_id": tenant.id,
}
expiryHours = dify_config.INVITE_EXPIRY_HOURS
redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data))
expiry_hours = dify_config.INVITE_EXPIRY_HOURS
redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data))
return token
@classmethod

View File

@ -1,7 +1,7 @@
from services.errors.base import BaseServiceError
class AccountNotFound(BaseServiceError):
class AccountNotFoundError(BaseServiceError):
pass
@ -25,7 +25,7 @@ class LinkAccountIntegrateError(BaseServiceError):
pass
class TenantNotFound(BaseServiceError):
class TenantNotFoundError(BaseServiceError):
pass

View File

@ -6,7 +6,7 @@ import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedException, IndexingRunner
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@ -106,7 +106,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logging.info(
click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")
)
except DocumentIsPausedException as ex:
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass

View File

@ -6,7 +6,7 @@ import click
from celery import shared_task
from configs import dify_config
from core.indexing_runner import DocumentIsPausedException, IndexingRunner
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
@ -72,7 +72,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
except DocumentIsPausedException as ex:
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass

View File

@ -6,7 +6,7 @@ import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedException, IndexingRunner
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
@ -69,7 +69,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
indexing_runner.run([document])
end_at = time.perf_counter()
logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green"))
except DocumentIsPausedException as ex:
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass

View File

@ -6,7 +6,7 @@ import click
from celery import shared_task
from configs import dify_config
from core.indexing_runner import DocumentIsPausedException, IndexingRunner
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
@ -88,7 +88,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
except DocumentIsPausedException as ex:
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass

View File

@ -5,7 +5,7 @@ import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedException, IndexingRunner
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
from models.dataset import Document
@ -39,7 +39,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
logging.info(
click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green")
)
except DocumentIsPausedException as ex:
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass

View File

@ -70,6 +70,7 @@ class MockTEIClass:
},
}
@staticmethod
def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
# Example response:
# [

View File

@ -7,6 +7,7 @@ from _pytest.monkeypatch import MonkeyPatch
class MockedHttp:
@staticmethod
def httpx_request(
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
) -> httpx.Response:

View File

@ -13,7 +13,7 @@ from xinference_client.types import Embedding
class MockTcvectordbClass:
def VectorDBClient(
def mock_vector_db_client(
self,
url=None,
username="",
@ -110,7 +110,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.VectorDBClient)
monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client)
monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases)
monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection)
monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections)

View File

@ -10,6 +10,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false") == "true"
class MockedHttp:
@staticmethod
def httpx_request(
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
) -> httpx.Response:

View File

@ -1,11 +1,11 @@
import pytest
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
CODE_LANGUAGE = "unsupported_language"
def test_unsupported_with_code_template():
with pytest.raises(CodeExecutionException) as e:
with pytest.raises(CodeExecutionError) as e:
CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={})
assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}"