chore: refurbish python code by applying Pylint linter rules (#8322)

This commit is contained in:
Bowen Liang 2024-09-13 22:42:08 +08:00 committed by GitHub
parent 1ab81b4972
commit a1104ab97e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
126 changed files with 253 additions and 272 deletions

View File

@ -164,7 +164,7 @@ def initialize_extensions(app):
@login_manager.request_loader @login_manager.request_loader
def load_user_from_request(request_from_flask_login): def load_user_from_request(request_from_flask_login):
"""Load user based on the request.""" """Load user based on the request."""
if request.blueprint not in ["console", "inner_api"]: if request.blueprint not in {"console", "inner_api"}:
return None return None
# Check if the user_id contains a dot, indicating the old format # Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get("Authorization", "") auth_header = request.headers.get("Authorization", "")

View File

@ -140,9 +140,9 @@ def reset_encrypt_key_pair():
@click.command("vdb-migrate", help="migrate vector db.") @click.command("vdb-migrate", help="migrate vector db.")
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") @click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
def vdb_migrate(scope: str): def vdb_migrate(scope: str):
if scope in ["knowledge", "all"]: if scope in {"knowledge", "all"}:
migrate_knowledge_vector_database() migrate_knowledge_vector_database()
if scope in ["annotation", "all"]: if scope in {"annotation", "all"}:
migrate_annotation_vector_database() migrate_annotation_vector_database()

View File

@ -94,7 +94,7 @@ class ChatMessageTextApi(Resource):
message_id = args.get("message_id", None) message_id = args.get("message_id", None)
text = args.get("text", None) text = args.get("text", None)
if ( if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow and app_model.workflow
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):

View File

@ -71,7 +71,7 @@ class OAuthCallback(Resource):
account = _generate_account(provider, user_info) account = _generate_account(provider, user_info)
# Check account status # Check account status
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
return {"error": "Account is banned or closed."}, 403 return {"error": "Account is banned or closed."}, 403
if account.status == AccountStatus.PENDING.value: if account.status == AccountStatus.PENDING.value:

View File

@ -354,7 +354,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
if document.indexing_status in ["completed", "error"]: if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError() raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule data_process_rule = document.dataset_process_rule
@ -421,7 +421,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
info_list = [] info_list = []
extract_settings = [] extract_settings = []
for document in documents: for document in documents:
if document.indexing_status in ["completed", "error"]: if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError() raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict data_source_info = document.data_source_info_dict
# format document files info # format document files info
@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource):
db.session.commit() db.session.commit()
elif action == "resume": elif action == "resume":
if document.indexing_status not in ["paused", "error"]: if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.") raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None document.paused_by = None

View File

@ -81,7 +81,7 @@ class ChatTextApi(InstalledAppResource):
message_id = args.get("message_id", None) message_id = args.get("message_id", None)
text = args.get("text", None) text = args.get("text", None)
if ( if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow and app_model.workflow
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):

View File

@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource):
def post(self, installed_app, task_id): def post(self, installed_app, task_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

View File

@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource):
def get(self, installed_app): def get(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource):
def delete(self, installed_app, c_id): def delete(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource):
def post(self, installed_app, c_id): def post(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource):
def patch(self, installed_app, c_id): def patch(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource):
def patch(self, installed_app, c_id): def patch(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)

View File

@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource):
"app_owner_tenant_id": installed_app.app_owner_tenant_id, "app_owner_tenant_id": installed_app.app_owner_tenant_id,
"is_pinned": installed_app.is_pinned, "is_pinned": installed_app.is_pinned,
"last_used_at": installed_app.last_used_at, "last_used_at": installed_app.last_used_at,
"editable": current_user.role in ["owner", "admin"], "editable": current_user.role in {"owner", "admin"},
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
} }
for installed_app in installed_apps for installed_app in installed_apps

View File

@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
message_id = str(message_id) message_id = str(message_id)

View File

@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource):
"""Retrieve app parameters.""" """Retrieve app parameters."""
app_model = installed_app.app app_model = installed_app.app
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow workflow = app_model.workflow
if workflow is None: if workflow is None:
raise AppUnavailableError() raise AppUnavailableError()

View File

@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource):
raise TooManyFilesError() raise TooManyFilesError()
extension = file.filename.split(".")[-1] extension = file.filename.split(".")[-1]
if extension.lower() not in ["svg", "png"]: if extension.lower() not in {"svg", "png"}:
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
try: try:

View File

@ -42,7 +42,7 @@ class AppParameterApi(Resource):
@marshal_with(parameters_fields) @marshal_with(parameters_fields)
def get(self, app_model: App): def get(self, app_model: App):
"""Retrieve app parameters.""" """Retrieve app parameters."""
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow workflow = app_model.workflow
if workflow is None: if workflow is None:
raise AppUnavailableError() raise AppUnavailableError()

View File

@ -79,7 +79,7 @@ class TextApi(Resource):
message_id = args.get("message_id", None) message_id = args.get("message_id", None)
text = args.get("text", None) text = args.get("text", None)
if ( if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow and app_model.workflow
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):

View File

@ -96,7 +96,7 @@ class ChatApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser): def post(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -144,7 +144,7 @@ class ChatStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id): def post(self, app_model: App, end_user: EndUser, task_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

View File

@ -18,7 +18,7 @@ class ConversationApi(Resource):
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser): def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -52,7 +52,7 @@ class ConversationDetailApi(Resource):
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
def delete(self, app_model: App, end_user: EndUser, c_id): def delete(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
@ -69,7 +69,7 @@ class ConversationRenameApi(Resource):
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, app_model: App, end_user: EndUser, c_id): def post(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)

View File

@ -76,7 +76,7 @@ class MessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser): def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource):
def get(self, app_model: App, end_user: EndUser, message_id): def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id) message_id = str(message_id)
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
try: try:

View File

@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource):
@marshal_with(parameters_fields) @marshal_with(parameters_fields)
def get(self, app_model: App, end_user): def get(self, app_model: App, end_user):
"""Retrieve app parameters.""" """Retrieve app parameters."""
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow workflow = app_model.workflow
if workflow is None: if workflow is None:
raise AppUnavailableError() raise AppUnavailableError()

View File

@ -78,7 +78,7 @@ class TextApi(WebApiResource):
message_id = args.get("message_id", None) message_id = args.get("message_id", None)
text = args.get("text", None) text = args.get("text", None)
if ( if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow and app_model.workflow
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):

View File

@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource): class ChatApi(WebApiResource):
def post(self, app_model, end_user): def post(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -136,7 +136,7 @@ class ChatApi(WebApiResource):
class ChatStopApi(WebApiResource): class ChatStopApi(WebApiResource):
def post(self, app_model, end_user, task_id): def post(self, app_model, end_user, task_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

View File

@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource):
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user): def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource):
class ConversationApi(WebApiResource): class ConversationApi(WebApiResource):
def delete(self, app_model, end_user, c_id): def delete(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource):
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id): def post(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource):
class ConversationPinApi(WebApiResource): class ConversationPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id): def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource):
class ConversationUnPinApi(WebApiResource): class ConversationUnPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id): def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)

View File

@ -78,7 +78,7 @@ class MessageListApi(WebApiResource):
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user): def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource):
class MessageSuggestedQuestionApi(WebApiResource): class MessageSuggestedQuestionApi(WebApiResource):
def get(self, app_model, end_user, message_id): def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotCompletionAppError() raise NotCompletionAppError()
message_id = str(message_id) message_id = str(message_id)

View File

@ -90,7 +90,7 @@ class CotAgentOutputParser:
if not in_code_block and not in_json: if not in_code_block and not in_json:
if delta.lower() == action_str[action_idx] and action_idx == 0: if delta.lower() == action_str[action_idx] and action_idx == 0:
if last_character not in ["\n", " ", ""]: if last_character not in {"\n", " ", ""}:
index += steps index += steps
yield delta yield delta
continue continue
@ -117,7 +117,7 @@ class CotAgentOutputParser:
action_idx = 0 action_idx = 0
if delta.lower() == thought_str[thought_idx] and thought_idx == 0: if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
if last_character not in ["\n", " ", ""]: if last_character not in {"\n", " ", ""}:
index += steps index += steps
yield delta yield delta
continue continue

View File

@ -29,7 +29,7 @@ class BaseAppConfigManager:
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict) additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
additional_features.file_upload = FileUploadConfigManager.convert( additional_features.file_upload = FileUploadConfigManager.convert(
config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT] config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
) )
additional_features.opening_statement, additional_features.suggested_questions = ( additional_features.opening_statement, additional_features.suggested_questions = (

View File

@ -18,7 +18,7 @@ class AgentConfigManager:
if agent_strategy == "function_call": if agent_strategy == "function_call":
strategy = AgentEntity.Strategy.FUNCTION_CALLING strategy = AgentEntity.Strategy.FUNCTION_CALLING
elif agent_strategy == "cot" or agent_strategy == "react": elif agent_strategy in {"cot", "react"}:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else: else:
# old configs, try to detect default strategy # old configs, try to detect default strategy
@ -43,10 +43,10 @@ class AgentConfigManager:
agent_tools.append(AgentToolEntity(**agent_tool_properties)) agent_tools.append(AgentToolEntity(**agent_tool_properties))
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [ if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
"react_router", "react_router",
"router", "router",
]: }:
agent_prompt = agent_dict.get("prompt", None) or {} agent_prompt = agent_dict.get("prompt", None) or {}
# check model mode # check model mode
model_mode = config.get("model", {}).get("mode", "completion") model_mode = config.get("model", {}).get("mode", "completion")

View File

@ -167,7 +167,7 @@ class DatasetConfigManager:
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
has_datasets = False has_datasets = False
if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]: if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
for tool in config["agent_mode"]["tools"]: for tool in config["agent_mode"]["tools"]:
key = list(tool.keys())[0] key = list(tool.keys())[0]
if key == "dataset": if key == "dataset":

View File

@ -42,12 +42,12 @@ class BasicVariablesConfigManager:
variable=variable["variable"], type=variable["type"], config=variable["config"] variable=variable["variable"], type=variable["type"], config=variable["config"]
) )
) )
elif variable_type in [ elif variable_type in {
VariableEntityType.TEXT_INPUT, VariableEntityType.TEXT_INPUT,
VariableEntityType.PARAGRAPH, VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER, VariableEntityType.NUMBER,
VariableEntityType.SELECT, VariableEntityType.SELECT,
]: }:
variable = variables[variable_type] variable = variables[variable_type]
variable_entities.append( variable_entities.append(
VariableEntity( VariableEntity(
@ -97,7 +97,7 @@ class BasicVariablesConfigManager:
variables = [] variables = []
for item in config["user_input_form"]: for item in config["user_input_form"]:
key = list(item.keys())[0] key = list(item.keys())[0]
if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
form_item = item[key] form_item = item[key]

View File

@ -54,14 +54,14 @@ class FileUploadConfigManager:
if is_vision: if is_vision:
detail = config["file_upload"]["image"]["detail"] detail = config["file_upload"]["image"]["detail"]
if detail not in ["high", "low"]: if detail not in {"high", "low"}:
raise ValueError("detail must be in ['high', 'low']") raise ValueError("detail must be in ['high', 'low']")
transfer_methods = config["file_upload"]["image"]["transfer_methods"] transfer_methods = config["file_upload"]["image"]["transfer_methods"]
if not isinstance(transfer_methods, list): if not isinstance(transfer_methods, list):
raise ValueError("transfer_methods must be of list type") raise ValueError("transfer_methods must be of list type")
for method in transfer_methods: for method in transfer_methods:
if method not in ["remote_url", "local_file"]: if method not in {"remote_url", "local_file"}:
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
return config, ["file_upload"] return config, ["file_upload"]

View File

@ -73,7 +73,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
raise ValueError("Workflow not initialized") raise ValueError("Workflow not initialized")
user_id = None user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user: if end_user:
user_id = end_user.session_id user_id = end_user.session_id
@ -175,7 +175,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id=self.application_generate_entity.user_id, user_id=self.application_generate_entity.user_id,
user_from=( user_from=(
UserFrom.ACCOUNT UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER else UserFrom.END_USER
), ),
invoke_from=self.application_generate_entity.invoke_from, invoke_from=self.application_generate_entity.invoke_from,

View File

@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
def convert( def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> dict[str, Any] | Generator[str, Any, None]: ) -> dict[str, Any] | Generator[str, Any, None]:
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse): if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response) return cls.convert_blocking_full_response(response)
else: else:

View File

@ -22,11 +22,11 @@ class BaseAppGenerator:
return var.default or "" return var.default or ""
if ( if (
var.type var.type
in ( in {
VariableEntityType.TEXT_INPUT, VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT, VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH, VariableEntityType.PARAGRAPH,
) }
and user_input_value and user_input_value
and not isinstance(user_input_value, str) and not isinstance(user_input_value, str)
): ):
@ -44,7 +44,7 @@ class BaseAppGenerator:
options = var.options or [] options = var.options or []
if user_input_value not in options: if user_input_value not in options:
raise ValueError(f"{var.variable} in input form must be one of the following: {options}") raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
if var.max_length and user_input_value and len(user_input_value) > var.max_length: if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters") raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")

View File

@ -32,7 +32,7 @@ class AppQueueManager:
self._user_id = user_id self._user_id = user_id
self._invoke_from = invoke_from self._invoke_from = invoke_from
user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
redis_client.setex( redis_client.setex(
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
) )
@ -118,7 +118,7 @@ class AppQueueManager:
if result is None: if result is None:
return return
user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
if result.decode("utf-8") != f"{user_prefix}-{user_id}": if result.decode("utf-8") != f"{user_prefix}-{user_id}":
return return

View File

@ -148,7 +148,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
# get from source # get from source
end_user_id = None end_user_id = None
account_id = None account_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
from_source = "api" from_source = "api"
end_user_id = application_generate_entity.user_id end_user_id = application_generate_entity.user_id
else: else:
@ -165,11 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
model_provider = application_generate_entity.model_conf.provider model_provider = application_generate_entity.model_conf.provider
model_id = application_generate_entity.model_conf.model model_id = application_generate_entity.model_conf.model
override_model_configs = None override_model_configs = None
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [ if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in {
AppMode.AGENT_CHAT, AppMode.AGENT_CHAT,
AppMode.CHAT, AppMode.CHAT,
AppMode.COMPLETION, AppMode.COMPLETION,
]: }:
override_model_configs = app_config.app_model_config_dict override_model_configs = app_config.app_model_config_dict
# get conversation introduction # get conversation introduction

View File

@ -53,7 +53,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = cast(WorkflowAppConfig, app_config) app_config = cast(WorkflowAppConfig, app_config)
user_id = None user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user: if end_user:
user_id = end_user.session_id user_id = end_user.session_id
@ -113,7 +113,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
user_id=self.application_generate_entity.user_id, user_id=self.application_generate_entity.user_id,
user_from=( user_from=(
UserFrom.ACCOUNT UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER else UserFrom.END_USER
), ),
invoke_from=self.application_generate_entity.invoke_from, invoke_from=self.application_generate_entity.invoke_from,

View File

@ -63,7 +63,7 @@ class AnnotationReplyFeature:
score = documents[0].metadata["score"] score = documents[0].metadata["score"]
annotation = AppAnnotationService.get_annotation_by_id(annotation_id) annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
if annotation: if annotation:
if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]: if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}:
from_source = "api" from_source = "api"
else: else:
from_source = "console" from_source = "console"

View File

@ -372,7 +372,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._message, self._message,
application_generate_entity=self._application_generate_entity, application_generate_entity=self._application_generate_entity,
conversation=self._conversation, conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT] is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
and self._application_generate_entity.conversation_id is None, and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras, extras=self._application_generate_entity.extras,
) )

View File

@ -383,7 +383,7 @@ class WorkflowCycleManage:
:param workflow_node_execution: workflow node execution :param workflow_node_execution: workflow node execution
:return: :return:
""" """
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
response = NodeStartStreamResponse( response = NodeStartStreamResponse(
@ -430,7 +430,7 @@ class WorkflowCycleManage:
:param workflow_node_execution: workflow node execution :param workflow_node_execution: workflow node execution
:return: :return:
""" """
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
return NodeFinishStreamResponse( return NodeFinishStreamResponse(

View File

@ -29,7 +29,7 @@ class DatasetIndexToolCallbackHandler:
source="app", source="app",
source_app_id=self._app_id, source_app_id=self._app_id,
created_by_role=( created_by_role=(
"account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
), ),
created_by=self._user_id, created_by=self._user_id,
) )

View File

@ -292,7 +292,7 @@ class IndexingRunner:
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
) -> list[Document]: ) -> list[Document]:
# load file # load file
if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]: if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
return [] return []
data_source_info = dataset_document.data_source_info_dict data_source_info = dataset_document.data_source_info_dict

View File

@ -52,7 +52,7 @@ class TokenBufferMemory:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files: if files:
file_extra_config = None file_extra_config = None
if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
else: else:
if message.workflow_run_id: if message.workflow_run_id:

View File

@ -27,17 +27,17 @@ class ModelType(Enum):
:return: model type :return: model type
""" """
if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value: if origin_model_type in {"text-generation", cls.LLM.value}:
return cls.LLM return cls.LLM
elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value: elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
return cls.TEXT_EMBEDDING return cls.TEXT_EMBEDDING
elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value: elif origin_model_type in {"reranking", cls.RERANK.value}:
return cls.RERANK return cls.RERANK
elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value: elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
return cls.SPEECH2TEXT return cls.SPEECH2TEXT
elif origin_model_type == "tts" or origin_model_type == cls.TTS.value: elif origin_model_type in {"tts", cls.TTS.value}:
return cls.TTS return cls.TTS
elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value: elif origin_model_type in {"text2img", cls.TEXT2IMG.value}:
return cls.TEXT2IMG return cls.TEXT2IMG
elif origin_model_type == cls.MODERATION.value: elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION return cls.MODERATION

View File

@ -494,7 +494,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
mime_type = data_split[0].replace("data:", "") mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1] base64_data = data_split[1]
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError( raise ValueError(
f"Unsupported image type {mime_type}, " f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp" f"only support image/jpeg, image/png, image/gif, and image/webp"

View File

@ -85,14 +85,14 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
for i in range(len(sentences)) for i in range(len(sentences))
] ]
for future in futures: for future in futures:
yield from future.result().__enter__().iter_bytes(1024) yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801
else: else:
response = client.audio.speech.with_streaming_response.create( response = client.audio.speech.with_streaming_response.create(
model=model, voice=voice, response_format="mp3", input=content_text.strip() model=model, voice=voice, response_format="mp3", input=content_text.strip()
) )
yield from response.__enter__().iter_bytes(1024) yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801
except Exception as ex: except Exception as ex:
raise InvokeBadRequestError(str(ex)) raise InvokeBadRequestError(str(ex))

View File

@ -454,7 +454,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
base64_data = data_split[1] base64_data = data_split[1]
image_content = base64.b64decode(base64_data) image_content = base64.b64decode(base64_data)
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError( raise ValueError(
f"Unsupported image type {mime_type}, " f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp" f"only support image/jpeg, image/png, image/gif, and image/webp"
@ -886,16 +886,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
if error_code == "AccessDeniedException": if error_code == "AccessDeniedException":
return InvokeAuthorizationError(error_msg) return InvokeAuthorizationError(error_msg)
elif error_code in ["ResourceNotFoundException", "ValidationException"]: elif error_code in {"ResourceNotFoundException", "ValidationException"}:
return InvokeBadRequestError(error_msg) return InvokeBadRequestError(error_msg)
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
return InvokeRateLimitError(error_msg) return InvokeRateLimitError(error_msg)
elif error_code in [ elif error_code in {
"ModelTimeoutException", "ModelTimeoutException",
"ModelErrorException", "ModelErrorException",
"InternalServerException", "InternalServerException",
"ModelNotReadyException", "ModelNotReadyException",
]: }:
return InvokeServerUnavailableError(error_msg) return InvokeServerUnavailableError(error_msg)
elif error_code == "ModelStreamErrorException": elif error_code == "ModelStreamErrorException":
return InvokeConnectionError(error_msg) return InvokeConnectionError(error_msg)

View File

@ -186,16 +186,16 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
if error_code == "AccessDeniedException": if error_code == "AccessDeniedException":
return InvokeAuthorizationError(error_msg) return InvokeAuthorizationError(error_msg)
elif error_code in ["ResourceNotFoundException", "ValidationException"]: elif error_code in {"ResourceNotFoundException", "ValidationException"}:
return InvokeBadRequestError(error_msg) return InvokeBadRequestError(error_msg)
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
return InvokeRateLimitError(error_msg) return InvokeRateLimitError(error_msg)
elif error_code in [ elif error_code in {
"ModelTimeoutException", "ModelTimeoutException",
"ModelErrorException", "ModelErrorException",
"InternalServerException", "InternalServerException",
"ModelNotReadyException", "ModelNotReadyException",
]: }:
return InvokeServerUnavailableError(error_msg) return InvokeServerUnavailableError(error_msg)
elif error_code == "ModelStreamErrorException": elif error_code == "ModelStreamErrorException":
return InvokeConnectionError(error_msg) return InvokeConnectionError(error_msg)

View File

@ -6,10 +6,10 @@ from collections.abc import Generator
from typing import Optional, Union, cast from typing import Optional, Union, cast
import google.ai.generativelanguage as glm import google.ai.generativelanguage as glm
import google.api_core.exceptions as exceptions
import google.generativeai as genai import google.generativeai as genai
import google.generativeai.client as client
import requests import requests
from google.api_core import exceptions
from google.generativeai import client
from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
from google.generativeai.types.content_types import to_part from google.generativeai.types.content_types import to_part
from PIL import Image from PIL import Image

View File

@ -77,7 +77,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
if "huggingfacehub_api_type" not in credentials: if "huggingfacehub_api_type" not in credentials:
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")
if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"): if credentials["huggingfacehub_api_type"] not in {"inference_endpoints", "hosted_inference_api"}:
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")
if "huggingfacehub_api_token" not in credentials: if "huggingfacehub_api_token" not in credentials:
@ -94,7 +94,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
credentials["huggingfacehub_api_token"], model credentials["huggingfacehub_api_token"], model
) )
if credentials["task_type"] not in ("text2text-generation", "text-generation"): if credentials["task_type"] not in {"text2text-generation", "text-generation"}:
raise CredentialsValidateFailedError( raise CredentialsValidateFailedError(
"Huggingface Hub Task Type must be one of text2text-generation, text-generation." "Huggingface Hub Task Type must be one of text2text-generation, text-generation."
) )

View File

@ -75,7 +75,7 @@ class TeiHelper:
if len(model_type.keys()) < 1: if len(model_type.keys()) < 1:
raise RuntimeError("model_type is empty") raise RuntimeError("model_type is empty")
model_type = list(model_type.keys())[0] model_type = list(model_type.keys())[0]
if model_type not in ["embedding", "reranker"]: if model_type not in {"embedding", "reranker"}:
raise RuntimeError(f"invalid model_type: {model_type}") raise RuntimeError(f"invalid model_type: {model_type}")
max_input_length = response_json.get("max_input_length", 512) max_input_length = response_json.get("max_input_length", 512)

View File

@ -100,9 +100,9 @@ class MinimaxChatCompletion:
return self._handle_chat_generate_response(response) return self._handle_chat_generate_response(response)
def _handle_error(self, code: int, msg: str): def _handle_error(self, code: int, msg: str):
if code == 1000 or code == 1001 or code == 1013 or code == 1027: if code in {1000, 1001, 1013, 1027}:
raise InternalServerError(msg) raise InternalServerError(msg)
elif code == 1002 or code == 1039: elif code in {1002, 1039}:
raise RateLimitReachedError(msg) raise RateLimitReachedError(msg)
elif code == 1004: elif code == 1004:
raise InvalidAuthenticationError(msg) raise InvalidAuthenticationError(msg)

View File

@ -105,9 +105,9 @@ class MinimaxChatCompletionPro:
return self._handle_chat_generate_response(response) return self._handle_chat_generate_response(response)
def _handle_error(self, code: int, msg: str): def _handle_error(self, code: int, msg: str):
if code == 1000 or code == 1001 or code == 1013 or code == 1027: if code in {1000, 1001, 1013, 1027}:
raise InternalServerError(msg) raise InternalServerError(msg)
elif code == 1002 or code == 1039: elif code in {1002, 1039}:
raise RateLimitReachedError(msg) raise RateLimitReachedError(msg)
elif code == 1004: elif code == 1004:
raise InvalidAuthenticationError(msg) raise InvalidAuthenticationError(msg)

View File

@ -114,7 +114,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
raise CredentialsValidateFailedError("Invalid api key") raise CredentialsValidateFailedError("Invalid api key")
def _handle_error(self, code: int, msg: str): def _handle_error(self, code: int, msg: str):
if code == 1000 or code == 1001: if code in {1000, 1001}:
raise InternalServerError(msg) raise InternalServerError(msg)
elif code == 1002: elif code == 1002:
raise RateLimitReachedError(msg) raise RateLimitReachedError(msg)

View File

@ -125,7 +125,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
model_mode = self.get_model_mode(base_model, credentials) model_mode = self.get_model_mode(base_model, credentials)
# transform response format # transform response format
if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
stop = stop or [] stop = stop or []
if model_mode == LLMMode.CHAT: if model_mode == LLMMode.CHAT:
# chat model # chat model

View File

@ -89,14 +89,14 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
for i in range(len(sentences)) for i in range(len(sentences))
] ]
for future in futures: for future in futures:
yield from future.result().__enter__().iter_bytes(1024) yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801
else: else:
response = client.audio.speech.with_streaming_response.create( response = client.audio.speech.with_streaming_response.create(
model=model, voice=voice, response_format="mp3", input=content_text.strip() model=model, voice=voice, response_format="mp3", input=content_text.strip()
) )
yield from response.__enter__().iter_bytes(1024) yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801
except Exception as ex: except Exception as ex:
raise InvokeBadRequestError(str(ex)) raise InvokeBadRequestError(str(ex))

View File

@ -12,7 +12,6 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
credentials["endpoint_url"] = "https://openrouter.ai/api/v1" credentials["endpoint_url"] = "https://openrouter.ai/api/v1"
credentials["mode"] = self.get_model_mode(model).value credentials["mode"] = self.get_model_mode(model).value
credentials["function_calling_type"] = "tool_call" credentials["function_calling_type"] = "tool_call"
return
def _invoke( def _invoke(
self, self,

View File

@ -154,7 +154,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
) )
for key, value in input_properties: for key, value in input_properties:
if key not in ["system_prompt", "prompt"] and "stop" not in key: if key not in {"system_prompt", "prompt"} and "stop" not in key:
value_type = value.get("type") value_type = value.get("type")
if not value_type: if not value_type:

View File

@ -86,7 +86,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
) )
for input_property in input_properties: for input_property in input_properties:
if input_property[0] in ("text", "texts", "inputs"): if input_property[0] in {"text", "texts", "inputs"}:
text_input_key = input_property[0] text_input_key = input_property[0]
return text_input_key return text_input_key
@ -96,7 +96,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
def _generate_embeddings_by_text_input_key( def _generate_embeddings_by_text_input_key(
client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str] client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str]
) -> list[list[float]]: ) -> list[list[float]]:
if text_input_key in ("text", "inputs"): if text_input_key in {"text", "inputs"}:
embeddings = [] embeddings = []
for text in texts: for text in texts:
result = client.run(replicate_model_version, input={text_input_key: text}) result = client.run(replicate_model_version, input={text_input_key: text})

View File

@ -89,7 +89,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
:param tools: tools for tool calling :param tools: tools for tool calling
:return: :return:
""" """
if model in ["qwen-turbo-chat", "qwen-plus-chat"]: if model in {"qwen-turbo-chat", "qwen-plus-chat"}:
model = model.replace("-chat", "") model = model.replace("-chat", "")
if model == "farui-plus": if model == "farui-plus":
model = "qwen-farui-plus" model = "qwen-farui-plus"
@ -157,7 +157,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
mode = self.get_model_mode(model, credentials) mode = self.get_model_mode(model, credentials)
if model in ["qwen-turbo-chat", "qwen-plus-chat"]: if model in {"qwen-turbo-chat", "qwen-plus-chat"}:
model = model.replace("-chat", "") model = model.replace("-chat", "")
extra_model_kwargs = {} extra_model_kwargs = {}
@ -201,7 +201,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages :param prompt_messages: prompt messages
:return: llm response :return: llm response
""" """
if response.status_code != 200 and response.status_code != HTTPStatus.OK: if response.status_code not in {200, HTTPStatus.OK}:
raise ServiceUnavailableError(response.message) raise ServiceUnavailableError(response.message)
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
@ -240,7 +240,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
full_text = "" full_text = ""
tool_calls = [] tool_calls = []
for index, response in enumerate(responses): for index, response in enumerate(responses):
if response.status_code != 200 and response.status_code != HTTPStatus.OK: if response.status_code not in {200, HTTPStatus.OK}:
raise ServiceUnavailableError( raise ServiceUnavailableError(
f"Failed to invoke model {model}, status code: {response.status_code}, " f"Failed to invoke model {model}, status code: {response.status_code}, "
f"message: {response.message}" f"message: {response.message}"

View File

@ -93,7 +93,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
""" """
Code block mode wrapper for invoking large language model Code block mode wrapper for invoking large language model
""" """
if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
stop = stop or [] stop = stop or []
self._transform_chat_json_prompts( self._transform_chat_json_prompts(
model=model, model=model,

View File

@ -5,7 +5,6 @@ import logging
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union, cast from typing import Optional, Union, cast
import google.api_core.exceptions as exceptions
import google.auth.transport.requests import google.auth.transport.requests
import vertexai.generative_models as glm import vertexai.generative_models as glm
from anthropic import AnthropicVertex, Stream from anthropic import AnthropicVertex, Stream
@ -17,6 +16,7 @@ from anthropic.types import (
MessageStopEvent, MessageStopEvent,
MessageStreamEvent, MessageStreamEvent,
) )
from google.api_core import exceptions
from google.cloud import aiplatform from google.cloud import aiplatform
from google.oauth2 import service_account from google.oauth2 import service_account
from PIL import Image from PIL import Image
@ -346,7 +346,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
mime_type = data_split[0].replace("data:", "") mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1] base64_data = data_split[1]
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError( raise ValueError(
f"Unsupported image type {mime_type}, " f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp" f"only support image/jpeg, image/png, image/gif, and image/webp"

View File

@ -96,7 +96,6 @@ class Signer:
signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service) signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service)
sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str)) sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str))
request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials) request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials)
return
@staticmethod @staticmethod
def hashed_canonical_request_v4(request, meta): def hashed_canonical_request_v4(request, meta):
@ -105,7 +104,7 @@ class Signer:
signed_headers = {} signed_headers = {}
for key in request.headers: for key in request.headers:
if key in ["Content-Type", "Content-Md5", "Host"] or key.startswith("X-"): if key in {"Content-Type", "Content-Md5", "Host"} or key.startswith("X-"):
signed_headers[key.lower()] = request.headers[key] signed_headers[key.lower()] = request.headers[key]
if "host" in signed_headers: if "host" in signed_headers:

View File

@ -69,7 +69,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
""" """
Code block mode wrapper for invoking large language model Code block mode wrapper for invoking large language model
""" """
if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
response_format = model_parameters["response_format"] response_format = model_parameters["response_format"]
stop = stop or [] stop = stop or []
self._transform_json_prompts( self._transform_json_prompts(

View File

@ -103,7 +103,7 @@ class XinferenceHelper:
model_handle_type = "embedding" model_handle_type = "embedding"
elif response_json.get("model_type") == "audio": elif response_json.get("model_type") == "audio":
model_handle_type = "audio" model_handle_type = "audio"
if model_family and model_family in ["ChatTTS", "CosyVoice", "FishAudio"]: if model_family and model_family in {"ChatTTS", "CosyVoice", "FishAudio"}:
model_ability.append("text-to-audio") model_ability.append("text-to-audio")
else: else:
model_ability.append("audio-to-text") model_ability.append("audio-to-text")

View File

@ -186,10 +186,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
new_prompt_messages: list[PromptMessage] = [] new_prompt_messages: list[PromptMessage] = []
for prompt_message in prompt_messages: for prompt_message in prompt_messages:
copy_prompt_message = prompt_message.copy() copy_prompt_message = prompt_message.copy()
if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}:
if isinstance(copy_prompt_message.content, list): if isinstance(copy_prompt_message.content, list):
# check if model is 'glm-4v' # check if model is 'glm-4v'
if model not in ("glm-4v", "glm-4v-plus"): if model not in {"glm-4v", "glm-4v-plus"}:
# not support list message # not support list message
continue continue
# get image and # get image and
@ -209,10 +209,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
): ):
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
else: else:
if ( if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.TOOL}:
copy_prompt_message.role == PromptMessageRole.USER
or copy_prompt_message.role == PromptMessageRole.TOOL
):
new_prompt_messages.append(copy_prompt_message) new_prompt_messages.append(copy_prompt_message)
elif copy_prompt_message.role == PromptMessageRole.SYSTEM: elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content) new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
@ -226,7 +223,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
else: else:
new_prompt_messages.append(copy_prompt_message) new_prompt_messages.append(copy_prompt_message)
if model == "glm-4v" or model == "glm-4v-plus": if model in {"glm-4v", "glm-4v-plus"}:
params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
else: else:
params = {"model": model, "messages": [], **model_parameters} params = {"model": model, "messages": [], **model_parameters}
@ -270,11 +267,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
# chatglm model # chatglm model
for prompt_message in new_prompt_messages: for prompt_message in new_prompt_messages:
# merge system message to user message # merge system message to user message
if ( if prompt_message.role in {
prompt_message.role == PromptMessageRole.SYSTEM PromptMessageRole.SYSTEM,
or prompt_message.role == PromptMessageRole.TOOL PromptMessageRole.TOOL,
or prompt_message.role == PromptMessageRole.USER PromptMessageRole.USER,
): }:
if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user": if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user":
params["messages"][-1]["content"] += "\n\n" + prompt_message.content params["messages"][-1]["content"] += "\n\n" + prompt_message.content
else: else:

View File

@ -1,5 +1,4 @@
from __future__ import annotations from __future__ import annotations
from .fine_tuning_job import FineTuningJob as FineTuningJob from .fine_tuning_job import FineTuningJob, ListOfFineTuningJob
from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob from .fine_tuning_job_event import FineTuningJobEvent
from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent

View File

@ -75,7 +75,7 @@ class CommonValidator:
if not isinstance(value, str): if not isinstance(value, str):
raise ValueError(f"Variable {credential_form_schema.variable} should be string") raise ValueError(f"Variable {credential_form_schema.variable} should be string")
if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]: if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}:
# If the value is in options, no validation is performed # If the value is in options, no validation is performed
if credential_form_schema.options: if credential_form_schema.options:
if value not in [option.value for option in credential_form_schema.options]: if value not in [option.value for option in credential_form_schema.options]:
@ -83,7 +83,7 @@ class CommonValidator:
if credential_form_schema.type == FormType.SWITCH: if credential_form_schema.type == FormType.SWITCH:
# If the value is not in ['true', 'false'], an exception is thrown # If the value is not in ['true', 'false'], an exception is thrown
if value.lower() not in ["true", "false"]: if value.lower() not in {"true", "false"}:
raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")
value = True if value.lower() == "true" else False value = True if value.lower() == "true" else False

View File

@ -51,7 +51,7 @@ class ElasticSearchVector(BaseVector):
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
try: try:
parsed_url = urlparse(config.host) parsed_url = urlparse(config.host)
if parsed_url.scheme in ["http", "https"]: if parsed_url.scheme in {"http", "https"}:
hosts = f"{config.host}:{config.port}" hosts = f"{config.host}:{config.port}"
else: else:
hosts = f"http://{config.host}:{config.port}" hosts = f"http://{config.host}:{config.port}"
@ -94,7 +94,7 @@ class ElasticSearchVector(BaseVector):
return uuids return uuids
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
return self._client.exists(index=self._collection_name, id=id).__bool__() return bool(self._client.exists(index=self._collection_name, id=id))
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
for id in ids: for id in ids:

View File

@ -35,7 +35,7 @@ class MyScaleVector(BaseVector):
super().__init__(collection_name) super().__init__(collection_name)
self._config = config self._config = config
self._metric = metric self._metric = metric
self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC self._vec_order = SortOrder.ASC if metric.upper() in {"COSINE", "L2"} else SortOrder.DESC
self._client = get_client( self._client = get_client(
host=config.host, host=config.host,
port=config.port, port=config.port,
@ -92,7 +92,7 @@ class MyScaleVector(BaseVector):
@staticmethod @staticmethod
def escape_str(value: Any) -> str: def escape_str(value: Any) -> str:
return "".join(" " if c in ("\\", "'") else c for c in str(value)) return "".join(" " if c in {"\\", "'"} else c for c in str(value))
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'") results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")

View File

@ -223,15 +223,7 @@ class OracleVector(BaseVector):
words = pseg.cut(query) words = pseg.cut(query)
current_entity = "" current_entity = ""
for word, pos in words: for word, pos in words:
if ( if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
pos == "nr"
or pos == "Ng"
or pos == "eng"
or pos == "nz"
or pos == "n"
or pos == "ORG"
or pos == "v"
): # nr: 人名, ns: 地名, nt: 机构名
current_entity += word current_entity += word
else: else:
if current_entity: if current_entity:

View File

@ -98,17 +98,17 @@ class ExtractProcessor:
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
if etl_type == "Unstructured": if etl_type == "Unstructured":
if file_extension == ".xlsx" or file_extension == ".xls": if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path) extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf": elif file_extension == ".pdf":
extractor = PdfExtractor(file_path) extractor = PdfExtractor(file_path)
elif file_extension in [".md", ".markdown"]: elif file_extension in {".md", ".markdown"}:
extractor = ( extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url) UnstructuredMarkdownExtractor(file_path, unstructured_api_url)
if is_automatic if is_automatic
else MarkdownExtractor(file_path, autodetect_encoding=True) else MarkdownExtractor(file_path, autodetect_encoding=True)
) )
elif file_extension in [".htm", ".html"]: elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path) extractor = HtmlExtractor(file_path)
elif file_extension == ".docx": elif file_extension == ".docx":
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
@ -134,13 +134,13 @@ class ExtractProcessor:
else TextExtractor(file_path, autodetect_encoding=True) else TextExtractor(file_path, autodetect_encoding=True)
) )
else: else:
if file_extension == ".xlsx" or file_extension == ".xls": if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path) extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf": elif file_extension == ".pdf":
extractor = PdfExtractor(file_path) extractor = PdfExtractor(file_path)
elif file_extension in [".md", ".markdown"]: elif file_extension in {".md", ".markdown"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True) extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in [".htm", ".html"]: elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path) extractor = HtmlExtractor(file_path)
elif file_extension == ".docx": elif file_extension == ".docx":
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)

View File

@ -32,7 +32,7 @@ class FirecrawlApp:
else: else:
raise Exception(f'Failed to scrape URL. Error: {response["error"]}') raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
elif response.status_code in [402, 409, 500]: elif response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred") error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}") raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}")
else: else:

View File

@ -103,12 +103,12 @@ class NotionExtractor(BaseExtractor):
multi_select_list = property_value[type] multi_select_list = property_value[type]
for multi_select in multi_select_list: for multi_select in multi_select_list:
value.append(multi_select["name"]) value.append(multi_select["name"])
elif type == "rich_text" or type == "title": elif type in {"rich_text", "title"}:
if len(property_value[type]) > 0: if len(property_value[type]) > 0:
value = property_value[type][0]["plain_text"] value = property_value[type][0]["plain_text"]
else: else:
value = "" value = ""
elif type == "select" or type == "status": elif type in {"select", "status"}:
if property_value[type]: if property_value[type]:
value = property_value[type]["name"] value = property_value[type]["name"]
else: else:

View File

@ -115,7 +115,7 @@ class DatasetRetrieval:
available_datasets.append(dataset) available_datasets.append(dataset)
all_documents = [] all_documents = []
user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
all_documents = self.single_retrieve( all_documents = self.single_retrieve(
app_id, app_id,

View File

@ -35,7 +35,7 @@ def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> l
splits = re.split(separator, text) splits = re.split(separator, text)
else: else:
splits = list(text) splits = list(text)
return [s for s in splits if (s != "" and s != "\n")] return [s for s in splits if (s not in {"", "\n"})]
class TextSplitter(BaseDocumentTransformer, ABC): class TextSplitter(BaseDocumentTransformer, ABC):

View File

@ -68,7 +68,7 @@ class AppToolProviderEntity(ToolProviderController):
label = input_form[form_type]["label"] label = input_form[form_type]["label"]
variable_name = input_form[form_type]["variable_name"] variable_name = input_form[form_type]["variable_name"]
options = input_form[form_type].get("options", []) options = input_form[form_type].get("options", [])
if form_type == "paragraph" or form_type == "text-input": if form_type in {"paragraph", "text-input"}:
tool["parameters"].append( tool["parameters"].append(
ToolParameter( ToolParameter(
name=variable_name, name=variable_name,

View File

@ -168,7 +168,7 @@ class AIPPTGenerateTool(BuiltinTool):
pass pass
elif event == "close": elif event == "close":
break break
elif event == "error" or event == "filter": elif event in {"error", "filter"}:
raise Exception(f"Failed to generate outline: {data}") raise Exception(f"Failed to generate outline: {data}")
return outline return outline
@ -213,7 +213,7 @@ class AIPPTGenerateTool(BuiltinTool):
pass pass
elif event == "close": elif event == "close":
break break
elif event == "error" or event == "filter": elif event in {"error", "filter"}:
raise Exception(f"Failed to generate content: {data}") raise Exception(f"Failed to generate content: {data}")
return content return content

View File

@ -39,11 +39,11 @@ class DallE3Tool(BuiltinTool):
n = tool_parameters.get("n", 1) n = tool_parameters.get("n", 1)
# get quality # get quality
quality = tool_parameters.get("quality", "standard") quality = tool_parameters.get("quality", "standard")
if quality not in ["standard", "hd"]: if quality not in {"standard", "hd"}:
return self.create_text_message("Invalid quality") return self.create_text_message("Invalid quality")
# get style # get style
style = tool_parameters.get("style", "vivid") style = tool_parameters.get("style", "vivid")
if style not in ["natural", "vivid"]: if style not in {"natural", "vivid"}:
return self.create_text_message("Invalid style") return self.create_text_message("Invalid style")
# set extra body # set extra body
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))

View File

@ -14,7 +14,7 @@ class SimpleCode(BuiltinTool):
language = tool_parameters.get("language", CodeLanguage.PYTHON3) language = tool_parameters.get("language", CodeLanguage.PYTHON3)
code = tool_parameters.get("code", "") code = tool_parameters.get("code", "")
if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]: if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
raise ValueError(f"Only python3 and javascript are supported, not {language}") raise ValueError(f"Only python3 and javascript are supported, not {language}")
result = CodeExecutor.execute_code(language, "", code) result = CodeExecutor.execute_code(language, "", code)

View File

@ -34,11 +34,11 @@ class CogView3Tool(BuiltinTool):
n = tool_parameters.get("n", 1) n = tool_parameters.get("n", 1)
# get quality # get quality
quality = tool_parameters.get("quality", "standard") quality = tool_parameters.get("quality", "standard")
if quality not in ["standard", "hd"]: if quality not in {"standard", "hd"}:
return self.create_text_message("Invalid quality") return self.create_text_message("Invalid quality")
# get style # get style
style = tool_parameters.get("style", "vivid") style = tool_parameters.get("style", "vivid")
if style not in ["natural", "vivid"]: if style not in {"natural", "vivid"}:
return self.create_text_message("Invalid style") return self.create_text_message("Invalid style")
# set extra body # set extra body
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))

View File

@ -49,11 +49,11 @@ class DallE3Tool(BuiltinTool):
n = tool_parameters.get("n", 1) n = tool_parameters.get("n", 1)
# get quality # get quality
quality = tool_parameters.get("quality", "standard") quality = tool_parameters.get("quality", "standard")
if quality not in ["standard", "hd"]: if quality not in {"standard", "hd"}:
return self.create_text_message("Invalid quality") return self.create_text_message("Invalid quality")
# get style # get style
style = tool_parameters.get("style", "vivid") style = tool_parameters.get("style", "vivid")
if style not in ["natural", "vivid"]: if style not in {"natural", "vivid"}:
return self.create_text_message("Invalid style") return self.create_text_message("Invalid style")
# call openapi dalle3 # call openapi dalle3

View File

@ -133,9 +133,9 @@ class GetWorksheetFieldsTool(BuiltinTool):
def _extract_options(self, control: dict) -> list: def _extract_options(self, control: dict) -> list:
options = [] options = []
if control["type"] in [9, 10, 11]: if control["type"] in {9, 10, 11}:
options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])])
elif control["type"] in [28, 36]: elif control["type"] in {28, 36}:
itemnames = control["advancedSetting"].get("itemnames") itemnames = control["advancedSetting"].get("itemnames")
if itemnames and itemnames.startswith("[{"): if itemnames and itemnames.startswith("[{"):
try: try:

View File

@ -183,11 +183,11 @@ class ListWorksheetRecordsTool(BuiltinTool):
type_id = field.get("typeId") type_id = field.get("typeId")
if type_id == 10: if type_id == 10:
value = value if isinstance(value, str) else "".join(value) value = value if isinstance(value, str) else "".join(value)
elif type_id in [28, 36]: elif type_id in {28, 36}:
value = field.get("options", {}).get(value, value) value = field.get("options", {}).get(value, value)
elif type_id in [26, 27, 48, 14]: elif type_id in {26, 27, 48, 14}:
value = self.process_value(value) value = self.process_value(value)
elif type_id in [35, 29]: elif type_id in {35, 29}:
value = self.parse_cascade_or_associated(field, value) value = self.parse_cascade_or_associated(field, value)
elif type_id == 40: elif type_id == 40:
value = self.parse_location(value) value = self.parse_location(value)

View File

@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool):
models_data=[], models_data=[],
headers=headers, headers=headers,
params=params, params=params,
recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"), recursive=result_type not in {"first sd_name", "first name sd_name pair"},
) )
result_str = "" result_str = ""

View File

@ -38,7 +38,7 @@ class SearchAPI:
return { return {
"engine": "google", "engine": "google",
"q": query, "q": query,
**{key: value for key, value in kwargs.items() if value not in [None, ""]}, **{key: value for key, value in kwargs.items() if value not in {None, ""}},
} }
@staticmethod @staticmethod

View File

@ -38,7 +38,7 @@ class SearchAPI:
return { return {
"engine": "google_jobs", "engine": "google_jobs",
"q": query, "q": query,
**{key: value for key, value in kwargs.items() if value not in [None, ""]}, **{key: value for key, value in kwargs.items() if value not in {None, ""}},
} }
@staticmethod @staticmethod

View File

@ -38,7 +38,7 @@ class SearchAPI:
return { return {
"engine": "google_news", "engine": "google_news",
"q": query, "q": query,
**{key: value for key, value in kwargs.items() if value not in [None, ""]}, **{key: value for key, value in kwargs.items() if value not in {None, ""}},
} }
@staticmethod @staticmethod

View File

@ -38,7 +38,7 @@ class SearchAPI:
"engine": "youtube_transcripts", "engine": "youtube_transcripts",
"video_id": video_id, "video_id": video_id,
"lang": language or "en", "lang": language or "en",
**{key: value for key, value in kwargs.items() if value not in [None, ""]}, **{key: value for key, value in kwargs.items() if value not in {None, ""}},
} }
@staticmethod @staticmethod

View File

@ -214,7 +214,7 @@ class Spider:
return requests.delete(url, headers=headers, stream=stream) return requests.delete(url, headers=headers, stream=stream)
def _handle_error(self, response, action): def _handle_error(self, response, action):
if response.status_code in [402, 409, 500]: if response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred") error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
else: else:

View File

@ -32,7 +32,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
model = tool_parameters.get("model", "core") model = tool_parameters.get("model", "core")
if model in ["sd3", "sd3-turbo"]: if model in {"sd3", "sd3-turbo"}:
payload["model"] = tool_parameters.get("model") payload["model"] = tool_parameters.get("model")
if model != "sd3-turbo": if model != "sd3-turbo":

View File

@ -38,7 +38,7 @@ class VannaTool(BuiltinTool):
vn = VannaDefault(model=model, api_key=api_key) vn = VannaDefault(model=model, api_key=api_key)
db_type = tool_parameters.get("db_type", "") db_type = tool_parameters.get("db_type", "")
if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]: if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}:
if not db_name: if not db_name:
return self.create_text_message("Please input database name") return self.create_text_message("Please input database name")
if not username: if not username:

View File

@ -19,7 +19,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
class BuiltinToolProviderController(ToolProviderController): class BuiltinToolProviderController(ToolProviderController):
def __init__(self, **data: Any) -> None: def __init__(self, **data: Any) -> None:
if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP: if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}:
super().__init__(**data) super().__init__(**data)
return return

View File

@ -153,10 +153,10 @@ class ToolProviderController(BaseModel, ABC):
# check type # check type
credential_schema = credentials_need_to_validate[credential_name] credential_schema = credentials_need_to_validate[credential_name]
if ( if credential_schema in {
credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT ToolProviderCredentials.CredentialsType.SECRET_INPUT,
or credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT ToolProviderCredentials.CredentialsType.TEXT_INPUT,
): }:
if not isinstance(credentials[credential_name], str): if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
@ -184,11 +184,11 @@ class ToolProviderController(BaseModel, ABC):
if credential_schema.default is not None: if credential_schema.default is not None:
default_value = credential_schema.default default_value = credential_schema.default
# parse default value into the correct type # parse default value into the correct type
if ( if credential_schema.type in {
credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT ToolProviderCredentials.CredentialsType.SECRET_INPUT,
or credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT ToolProviderCredentials.CredentialsType.TEXT_INPUT,
or credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT ToolProviderCredentials.CredentialsType.SELECT,
): }:
default_value = str(default_value) default_value = str(default_value)
credentials[credential_name] = default_value credentials[credential_name] = default_value

View File

@ -5,7 +5,7 @@ from urllib.parse import urlencode
import httpx import httpx
import core.helper.ssrf_proxy as ssrf_proxy from core.helper import ssrf_proxy
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
@ -191,7 +191,7 @@ class ApiTool(Tool):
else: else:
body = body body = body
if method in ("get", "head", "post", "put", "delete", "patch"): if method in {"get", "head", "post", "put", "delete", "patch"}:
response = getattr(ssrf_proxy, method)( response = getattr(ssrf_proxy, method)(
url, url,
params=params, params=params,
@ -224,9 +224,9 @@ class ApiTool(Tool):
elif option["type"] == "string": elif option["type"] == "string":
return str(value) return str(value)
elif option["type"] == "boolean": elif option["type"] == "boolean":
if str(value).lower() in ["true", "1"]: if str(value).lower() in {"true", "1"}:
return True return True
elif str(value).lower() in ["false", "0"]: elif str(value).lower() in {"false", "0"}:
return False return False
else: else:
continue # Not a boolean, try next option continue # Not a boolean, try next option

View File

@ -189,10 +189,7 @@ class ToolEngine:
result += response.message result += response.message
elif response.type == ToolInvokeMessage.MessageType.LINK: elif response.type == ToolInvokeMessage.MessageType.LINK:
result += f"result link: {response.message}. please tell user to check it." result += f"result link: {response.message}. please tell user to check it."
elif ( elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
response.type == ToolInvokeMessage.MessageType.IMAGE_LINK
or response.type == ToolInvokeMessage.MessageType.IMAGE
):
result += ( result += (
"image has been created and sent to user already, you do not need to create it," "image has been created and sent to user already, you do not need to create it,"
" just tell the user to check it now." " just tell the user to check it now."
@ -212,10 +209,7 @@ class ToolEngine:
result = [] result = []
for response in tool_response: for response in tool_response:
if ( if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
response.type == ToolInvokeMessage.MessageType.IMAGE_LINK
or response.type == ToolInvokeMessage.MessageType.IMAGE
):
mimetype = None mimetype = None
if response.meta.get("mime_type"): if response.meta.get("mime_type"):
mimetype = response.meta.get("mime_type") mimetype = response.meta.get("mime_type")
@ -297,7 +291,7 @@ class ToolEngine:
belongs_to="assistant", belongs_to="assistant",
url=message.url, url=message.url,
upload_file_id=None, upload_file_id=None,
created_by_role=("account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"), created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"),
created_by=user_id, created_by=user_id,
) )

View File

@ -19,7 +19,7 @@ class ToolFileMessageTransformer:
result = [] result = []
for message in messages: for message in messages:
if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK: if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
result.append(message) result.append(message)
elif message.type == ToolInvokeMessage.MessageType.IMAGE: elif message.type == ToolInvokeMessage.MessageType.IMAGE:
# try to download image # try to download image

View File

@ -165,7 +165,7 @@ class ApiBasedToolSchemaParser:
elif "schema" in parameter and "type" in parameter["schema"]: elif "schema" in parameter and "type" in parameter["schema"]:
typ = parameter["schema"]["type"] typ = parameter["schema"]["type"]
if typ == "integer" or typ == "number": if typ in {"integer", "number"}:
return ToolParameter.ToolParameterType.NUMBER return ToolParameter.ToolParameterType.NUMBER
elif typ == "boolean": elif typ == "boolean":
return ToolParameter.ToolParameterType.BOOLEAN return ToolParameter.ToolParameterType.BOOLEAN

View File

@ -313,7 +313,7 @@ def normalize_whitespace(text):
def is_leaf(element): def is_leaf(element):
return element.name in ["p", "li"] return element.name in {"p", "li"}
def is_text(element): def is_text(element):

View File

@ -51,7 +51,7 @@ class RouteNodeState(BaseModel):
:param run_result: run result :param run_result: run result
""" """
if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]: if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}:
raise Exception(f"Route state {self.id} already finished") raise Exception(f"Route state {self.id} already finished")
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:

View File

@ -148,11 +148,11 @@ class AnswerStreamGeneratorRouter:
for edge in reverse_edges: for edge in reverse_edges:
source_node_id = edge.source_node_id source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in ( if source_node_type in {
NodeType.ANSWER.value, NodeType.ANSWER.value,
NodeType.IF_ELSE.value, NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value, NodeType.QUESTION_CLASSIFIER.value,
): }:
answer_dependencies[answer_node_id].append(source_node_id) answer_dependencies[answer_node_id].append(source_node_id)
else: else:
cls._recursive_fetch_answer_dependencies( cls._recursive_fetch_answer_dependencies(

View File

@ -136,10 +136,10 @@ class EndStreamGeneratorRouter:
for edge in reverse_edges: for edge in reverse_edges:
source_node_id = edge.source_node_id source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in ( if source_node_type in {
NodeType.IF_ELSE.value, NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER, NodeType.QUESTION_CLASSIFIER,
): }:
end_dependencies[end_node_id].append(source_node_id) end_dependencies[end_node_id].append(source_node_id)
else: else:
cls._recursive_fetch_end_dependencies( cls._recursive_fetch_end_dependencies(

View File

@ -6,8 +6,8 @@ from urllib.parse import urlencode
import httpx import httpx
import core.helper.ssrf_proxy as ssrf_proxy
from configs import dify_config from configs import dify_config
from core.helper import ssrf_proxy
from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.http_request.entities import ( from core.workflow.nodes.http_request.entities import (
@ -176,7 +176,7 @@ class HttpExecutor:
elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set: elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set:
self.headers["Content-Type"] = "application/x-www-form-urlencoded" self.headers["Content-Type"] = "application/x-www-form-urlencoded"
if node_data.body.type in ["form-data", "x-www-form-urlencoded"]: if node_data.body.type in {"form-data", "x-www-form-urlencoded"}:
body = self._to_dict(body_data) body = self._to_dict(body_data)
if node_data.body.type == "form-data": if node_data.body.type == "form-data":
@ -187,7 +187,7 @@ class HttpExecutor:
self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
else: else:
self.body = urlencode(body) self.body = urlencode(body)
elif node_data.body.type in ["json", "raw-text"]: elif node_data.body.type in {"json", "raw-text"}:
self.body = body_data self.body = body_data
elif node_data.body.type == "none": elif node_data.body.type == "none":
self.body = "" self.body = ""
@ -258,7 +258,7 @@ class HttpExecutor:
"follow_redirects": True, "follow_redirects": True,
} }
if self.method in ("get", "head", "post", "put", "delete", "patch"): if self.method in {"get", "head", "post", "put", "delete", "patch"}:
response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
else: else:
raise ValueError(f"Invalid http method {self.method}") raise ValueError(f"Invalid http method {self.method}")

Some files were not shown because too many files have changed in this diff Show More