mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
chore: refurish python code by applying Pylint linter rules (#8322)
This commit is contained in:
parent
1ab81b4972
commit
a1104ab97e
|
@ -164,7 +164,7 @@ def initialize_extensions(app):
|
||||||
@login_manager.request_loader
|
@login_manager.request_loader
|
||||||
def load_user_from_request(request_from_flask_login):
|
def load_user_from_request(request_from_flask_login):
|
||||||
"""Load user based on the request."""
|
"""Load user based on the request."""
|
||||||
if request.blueprint not in ["console", "inner_api"]:
|
if request.blueprint not in {"console", "inner_api"}:
|
||||||
return None
|
return None
|
||||||
# Check if the user_id contains a dot, indicating the old format
|
# Check if the user_id contains a dot, indicating the old format
|
||||||
auth_header = request.headers.get("Authorization", "")
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
|
|
@ -140,9 +140,9 @@ def reset_encrypt_key_pair():
|
||||||
@click.command("vdb-migrate", help="migrate vector db.")
|
@click.command("vdb-migrate", help="migrate vector db.")
|
||||||
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
|
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
|
||||||
def vdb_migrate(scope: str):
|
def vdb_migrate(scope: str):
|
||||||
if scope in ["knowledge", "all"]:
|
if scope in {"knowledge", "all"}:
|
||||||
migrate_knowledge_vector_database()
|
migrate_knowledge_vector_database()
|
||||||
if scope in ["annotation", "all"]:
|
if scope in {"annotation", "all"}:
|
||||||
migrate_annotation_vector_database()
|
migrate_annotation_vector_database()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -94,7 +94,7 @@ class ChatMessageTextApi(Resource):
|
||||||
message_id = args.get("message_id", None)
|
message_id = args.get("message_id", None)
|
||||||
text = args.get("text", None)
|
text = args.get("text", None)
|
||||||
if (
|
if (
|
||||||
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||||
and app_model.workflow
|
and app_model.workflow
|
||||||
and app_model.workflow.features_dict
|
and app_model.workflow.features_dict
|
||||||
):
|
):
|
||||||
|
|
|
@ -71,7 +71,7 @@ class OAuthCallback(Resource):
|
||||||
|
|
||||||
account = _generate_account(provider, user_info)
|
account = _generate_account(provider, user_info)
|
||||||
# Check account status
|
# Check account status
|
||||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
|
||||||
return {"error": "Account is banned or closed."}, 403
|
return {"error": "Account is banned or closed."}, 403
|
||||||
|
|
||||||
if account.status == AccountStatus.PENDING.value:
|
if account.status == AccountStatus.PENDING.value:
|
||||||
|
|
|
@ -354,7 +354,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = self.get_document(dataset_id, document_id)
|
document = self.get_document(dataset_id, document_id)
|
||||||
|
|
||||||
if document.indexing_status in ["completed", "error"]:
|
if document.indexing_status in {"completed", "error"}:
|
||||||
raise DocumentAlreadyFinishedError()
|
raise DocumentAlreadyFinishedError()
|
||||||
|
|
||||||
data_process_rule = document.dataset_process_rule
|
data_process_rule = document.dataset_process_rule
|
||||||
|
@ -421,7 +421,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||||
info_list = []
|
info_list = []
|
||||||
extract_settings = []
|
extract_settings = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
if document.indexing_status in ["completed", "error"]:
|
if document.indexing_status in {"completed", "error"}:
|
||||||
raise DocumentAlreadyFinishedError()
|
raise DocumentAlreadyFinishedError()
|
||||||
data_source_info = document.data_source_info_dict
|
data_source_info = document.data_source_info_dict
|
||||||
# format document files info
|
# format document files info
|
||||||
|
@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource):
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
elif action == "resume":
|
elif action == "resume":
|
||||||
if document.indexing_status not in ["paused", "error"]:
|
if document.indexing_status not in {"paused", "error"}:
|
||||||
raise InvalidActionError("Document not in paused or error state.")
|
raise InvalidActionError("Document not in paused or error state.")
|
||||||
|
|
||||||
document.paused_by = None
|
document.paused_by = None
|
||||||
|
|
|
@ -81,7 +81,7 @@ class ChatTextApi(InstalledAppResource):
|
||||||
message_id = args.get("message_id", None)
|
message_id = args.get("message_id", None)
|
||||||
text = args.get("text", None)
|
text = args.get("text", None)
|
||||||
if (
|
if (
|
||||||
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||||
and app_model.workflow
|
and app_model.workflow
|
||||||
and app_model.workflow.features_dict
|
and app_model.workflow.features_dict
|
||||||
):
|
):
|
||||||
|
|
|
@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource):
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource):
|
||||||
def post(self, installed_app, task_id):
|
def post(self, installed_app, task_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||||
|
|
|
@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource):
|
||||||
def get(self, installed_app):
|
def get(self, installed_app):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource):
|
||||||
def delete(self, installed_app, c_id):
|
def delete(self, installed_app, c_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource):
|
||||||
def post(self, installed_app, c_id):
|
def post(self, installed_app, c_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource):
|
||||||
def patch(self, installed_app, c_id):
|
def patch(self, installed_app, c_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource):
|
||||||
def patch(self, installed_app, c_id):
|
def patch(self, installed_app, c_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
|
|
@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource):
|
||||||
"app_owner_tenant_id": installed_app.app_owner_tenant_id,
|
"app_owner_tenant_id": installed_app.app_owner_tenant_id,
|
||||||
"is_pinned": installed_app.is_pinned,
|
"is_pinned": installed_app.is_pinned,
|
||||||
"last_used_at": installed_app.last_used_at,
|
"last_used_at": installed_app.last_used_at,
|
||||||
"editable": current_user.role in ["owner", "admin"],
|
"editable": current_user.role in {"owner", "admin"},
|
||||||
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
|
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
|
||||||
}
|
}
|
||||||
for installed_app in installed_apps
|
for installed_app in installed_apps
|
||||||
|
|
|
@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||||
def get(self, installed_app, message_id):
|
def get(self, installed_app, message_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
|
@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource):
|
||||||
"""Retrieve app parameters."""
|
"""Retrieve app parameters."""
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
|
||||||
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
|
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||||
workflow = app_model.workflow
|
workflow = app_model.workflow
|
||||||
if workflow is None:
|
if workflow is None:
|
||||||
raise AppUnavailableError()
|
raise AppUnavailableError()
|
||||||
|
|
|
@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource):
|
||||||
raise TooManyFilesError()
|
raise TooManyFilesError()
|
||||||
|
|
||||||
extension = file.filename.split(".")[-1]
|
extension = file.filename.split(".")[-1]
|
||||||
if extension.lower() not in ["svg", "png"]:
|
if extension.lower() not in {"svg", "png"}:
|
||||||
raise UnsupportedFileTypeError()
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -42,7 +42,7 @@ class AppParameterApi(Resource):
|
||||||
@marshal_with(parameters_fields)
|
@marshal_with(parameters_fields)
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
"""Retrieve app parameters."""
|
"""Retrieve app parameters."""
|
||||||
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
|
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||||
workflow = app_model.workflow
|
workflow = app_model.workflow
|
||||||
if workflow is None:
|
if workflow is None:
|
||||||
raise AppUnavailableError()
|
raise AppUnavailableError()
|
||||||
|
|
|
@ -79,7 +79,7 @@ class TextApi(Resource):
|
||||||
message_id = args.get("message_id", None)
|
message_id = args.get("message_id", None)
|
||||||
text = args.get("text", None)
|
text = args.get("text", None)
|
||||||
if (
|
if (
|
||||||
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||||
and app_model.workflow
|
and app_model.workflow
|
||||||
and app_model.workflow.features_dict
|
and app_model.workflow.features_dict
|
||||||
):
|
):
|
||||||
|
|
|
@ -96,7 +96,7 @@ class ChatApi(Resource):
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||||
def post(self, app_model: App, end_user: EndUser):
|
def post(self, app_model: App, end_user: EndUser):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
@ -144,7 +144,7 @@ class ChatStopApi(Resource):
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||||
def post(self, app_model: App, end_user: EndUser, task_id):
|
def post(self, app_model: App, end_user: EndUser, task_id):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||||
|
|
|
@ -18,7 +18,7 @@ class ConversationApi(Resource):
|
||||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||||
def get(self, app_model: App, end_user: EndUser):
|
def get(self, app_model: App, end_user: EndUser):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
@ -52,7 +52,7 @@ class ConversationDetailApi(Resource):
|
||||||
@marshal_with(simple_conversation_fields)
|
@marshal_with(simple_conversation_fields)
|
||||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
@ -69,7 +69,7 @@ class ConversationRenameApi(Resource):
|
||||||
@marshal_with(simple_conversation_fields)
|
@marshal_with(simple_conversation_fields)
|
||||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
|
|
@ -76,7 +76,7 @@ class MessageListApi(Resource):
|
||||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||||
def get(self, app_model: App, end_user: EndUser):
|
def get(self, app_model: App, end_user: EndUser):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource):
|
||||||
def get(self, app_model: App, end_user: EndUser, message_id):
|
def get(self, app_model: App, end_user: EndUser, message_id):
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource):
|
||||||
@marshal_with(parameters_fields)
|
@marshal_with(parameters_fields)
|
||||||
def get(self, app_model: App, end_user):
|
def get(self, app_model: App, end_user):
|
||||||
"""Retrieve app parameters."""
|
"""Retrieve app parameters."""
|
||||||
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
|
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||||
workflow = app_model.workflow
|
workflow = app_model.workflow
|
||||||
if workflow is None:
|
if workflow is None:
|
||||||
raise AppUnavailableError()
|
raise AppUnavailableError()
|
||||||
|
|
|
@ -78,7 +78,7 @@ class TextApi(WebApiResource):
|
||||||
message_id = args.get("message_id", None)
|
message_id = args.get("message_id", None)
|
||||||
text = args.get("text", None)
|
text = args.get("text", None)
|
||||||
if (
|
if (
|
||||||
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||||
and app_model.workflow
|
and app_model.workflow
|
||||||
and app_model.workflow.features_dict
|
and app_model.workflow.features_dict
|
||||||
):
|
):
|
||||||
|
|
|
@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource):
|
||||||
class ChatApi(WebApiResource):
|
class ChatApi(WebApiResource):
|
||||||
def post(self, app_model, end_user):
|
def post(self, app_model, end_user):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
@ -136,7 +136,7 @@ class ChatApi(WebApiResource):
|
||||||
class ChatStopApi(WebApiResource):
|
class ChatStopApi(WebApiResource):
|
||||||
def post(self, app_model, end_user, task_id):
|
def post(self, app_model, end_user, task_id):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||||
|
|
|
@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource):
|
||||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||||
def get(self, app_model, end_user):
|
def get(self, app_model, end_user):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource):
|
||||||
class ConversationApi(WebApiResource):
|
class ConversationApi(WebApiResource):
|
||||||
def delete(self, app_model, end_user, c_id):
|
def delete(self, app_model, end_user, c_id):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource):
|
||||||
@marshal_with(simple_conversation_fields)
|
@marshal_with(simple_conversation_fields)
|
||||||
def post(self, app_model, end_user, c_id):
|
def post(self, app_model, end_user, c_id):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource):
|
||||||
class ConversationPinApi(WebApiResource):
|
class ConversationPinApi(WebApiResource):
|
||||||
def patch(self, app_model, end_user, c_id):
|
def patch(self, app_model, end_user, c_id):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource):
|
||||||
class ConversationUnPinApi(WebApiResource):
|
class ConversationUnPinApi(WebApiResource):
|
||||||
def patch(self, app_model, end_user, c_id):
|
def patch(self, app_model, end_user, c_id):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
|
|
@ -78,7 +78,7 @@ class MessageListApi(WebApiResource):
|
||||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||||
def get(self, app_model, end_user):
|
def get(self, app_model, end_user):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||||
class MessageSuggestedQuestionApi(WebApiResource):
|
class MessageSuggestedQuestionApi(WebApiResource):
|
||||||
def get(self, app_model, end_user, message_id):
|
def get(self, app_model, end_user, message_id):
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
|
@ -90,7 +90,7 @@ class CotAgentOutputParser:
|
||||||
|
|
||||||
if not in_code_block and not in_json:
|
if not in_code_block and not in_json:
|
||||||
if delta.lower() == action_str[action_idx] and action_idx == 0:
|
if delta.lower() == action_str[action_idx] and action_idx == 0:
|
||||||
if last_character not in ["\n", " ", ""]:
|
if last_character not in {"\n", " ", ""}:
|
||||||
index += steps
|
index += steps
|
||||||
yield delta
|
yield delta
|
||||||
continue
|
continue
|
||||||
|
@ -117,7 +117,7 @@ class CotAgentOutputParser:
|
||||||
action_idx = 0
|
action_idx = 0
|
||||||
|
|
||||||
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
|
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
|
||||||
if last_character not in ["\n", " ", ""]:
|
if last_character not in {"\n", " ", ""}:
|
||||||
index += steps
|
index += steps
|
||||||
yield delta
|
yield delta
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -29,7 +29,7 @@ class BaseAppConfigManager:
|
||||||
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
|
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
|
||||||
|
|
||||||
additional_features.file_upload = FileUploadConfigManager.convert(
|
additional_features.file_upload = FileUploadConfigManager.convert(
|
||||||
config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
|
config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
|
||||||
)
|
)
|
||||||
|
|
||||||
additional_features.opening_statement, additional_features.suggested_questions = (
|
additional_features.opening_statement, additional_features.suggested_questions = (
|
||||||
|
|
|
@ -18,7 +18,7 @@ class AgentConfigManager:
|
||||||
|
|
||||||
if agent_strategy == "function_call":
|
if agent_strategy == "function_call":
|
||||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||||
elif agent_strategy == "cot" or agent_strategy == "react":
|
elif agent_strategy in {"cot", "react"}:
|
||||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||||
else:
|
else:
|
||||||
# old configs, try to detect default strategy
|
# old configs, try to detect default strategy
|
||||||
|
@ -43,10 +43,10 @@ class AgentConfigManager:
|
||||||
|
|
||||||
agent_tools.append(AgentToolEntity(**agent_tool_properties))
|
agent_tools.append(AgentToolEntity(**agent_tool_properties))
|
||||||
|
|
||||||
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [
|
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
|
||||||
"react_router",
|
"react_router",
|
||||||
"router",
|
"router",
|
||||||
]:
|
}:
|
||||||
agent_prompt = agent_dict.get("prompt", None) or {}
|
agent_prompt = agent_dict.get("prompt", None) or {}
|
||||||
# check model mode
|
# check model mode
|
||||||
model_mode = config.get("model", {}).get("mode", "completion")
|
model_mode = config.get("model", {}).get("mode", "completion")
|
||||||
|
|
|
@ -167,7 +167,7 @@ class DatasetConfigManager:
|
||||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
||||||
|
|
||||||
has_datasets = False
|
has_datasets = False
|
||||||
if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]:
|
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
|
||||||
for tool in config["agent_mode"]["tools"]:
|
for tool in config["agent_mode"]["tools"]:
|
||||||
key = list(tool.keys())[0]
|
key = list(tool.keys())[0]
|
||||||
if key == "dataset":
|
if key == "dataset":
|
||||||
|
|
|
@ -42,12 +42,12 @@ class BasicVariablesConfigManager:
|
||||||
variable=variable["variable"], type=variable["type"], config=variable["config"]
|
variable=variable["variable"], type=variable["type"], config=variable["config"]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif variable_type in [
|
elif variable_type in {
|
||||||
VariableEntityType.TEXT_INPUT,
|
VariableEntityType.TEXT_INPUT,
|
||||||
VariableEntityType.PARAGRAPH,
|
VariableEntityType.PARAGRAPH,
|
||||||
VariableEntityType.NUMBER,
|
VariableEntityType.NUMBER,
|
||||||
VariableEntityType.SELECT,
|
VariableEntityType.SELECT,
|
||||||
]:
|
}:
|
||||||
variable = variables[variable_type]
|
variable = variables[variable_type]
|
||||||
variable_entities.append(
|
variable_entities.append(
|
||||||
VariableEntity(
|
VariableEntity(
|
||||||
|
@ -97,7 +97,7 @@ class BasicVariablesConfigManager:
|
||||||
variables = []
|
variables = []
|
||||||
for item in config["user_input_form"]:
|
for item in config["user_input_form"]:
|
||||||
key = list(item.keys())[0]
|
key = list(item.keys())[0]
|
||||||
if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
|
if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
|
||||||
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
|
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
|
||||||
|
|
||||||
form_item = item[key]
|
form_item = item[key]
|
||||||
|
|
|
@ -54,14 +54,14 @@ class FileUploadConfigManager:
|
||||||
|
|
||||||
if is_vision:
|
if is_vision:
|
||||||
detail = config["file_upload"]["image"]["detail"]
|
detail = config["file_upload"]["image"]["detail"]
|
||||||
if detail not in ["high", "low"]:
|
if detail not in {"high", "low"}:
|
||||||
raise ValueError("detail must be in ['high', 'low']")
|
raise ValueError("detail must be in ['high', 'low']")
|
||||||
|
|
||||||
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
|
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
|
||||||
if not isinstance(transfer_methods, list):
|
if not isinstance(transfer_methods, list):
|
||||||
raise ValueError("transfer_methods must be of list type")
|
raise ValueError("transfer_methods must be of list type")
|
||||||
for method in transfer_methods:
|
for method in transfer_methods:
|
||||||
if method not in ["remote_url", "local_file"]:
|
if method not in {"remote_url", "local_file"}:
|
||||||
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
|
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
|
||||||
|
|
||||||
return config, ["file_upload"]
|
return config, ["file_upload"]
|
||||||
|
|
|
@ -73,7 +73,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
raise ValueError("Workflow not initialized")
|
raise ValueError("Workflow not initialized")
|
||||||
|
|
||||||
user_id = None
|
user_id = None
|
||||||
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||||
if end_user:
|
if end_user:
|
||||||
user_id = end_user.session_id
|
user_id = end_user.session_id
|
||||||
|
@ -175,7 +175,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
user_id=self.application_generate_entity.user_id,
|
user_id=self.application_generate_entity.user_id,
|
||||||
user_from=(
|
user_from=(
|
||||||
UserFrom.ACCOUNT
|
UserFrom.ACCOUNT
|
||||||
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||||
else UserFrom.END_USER
|
else UserFrom.END_USER
|
||||||
),
|
),
|
||||||
invoke_from=self.application_generate_entity.invoke_from,
|
invoke_from=self.application_generate_entity.invoke_from,
|
||||||
|
|
|
@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
|
||||||
def convert(
|
def convert(
|
||||||
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
|
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
|
||||||
) -> dict[str, Any] | Generator[str, Any, None]:
|
) -> dict[str, Any] | Generator[str, Any, None]:
|
||||||
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
|
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
||||||
if isinstance(response, AppBlockingResponse):
|
if isinstance(response, AppBlockingResponse):
|
||||||
return cls.convert_blocking_full_response(response)
|
return cls.convert_blocking_full_response(response)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -22,11 +22,11 @@ class BaseAppGenerator:
|
||||||
return var.default or ""
|
return var.default or ""
|
||||||
if (
|
if (
|
||||||
var.type
|
var.type
|
||||||
in (
|
in {
|
||||||
VariableEntityType.TEXT_INPUT,
|
VariableEntityType.TEXT_INPUT,
|
||||||
VariableEntityType.SELECT,
|
VariableEntityType.SELECT,
|
||||||
VariableEntityType.PARAGRAPH,
|
VariableEntityType.PARAGRAPH,
|
||||||
)
|
}
|
||||||
and user_input_value
|
and user_input_value
|
||||||
and not isinstance(user_input_value, str)
|
and not isinstance(user_input_value, str)
|
||||||
):
|
):
|
||||||
|
@ -44,7 +44,7 @@ class BaseAppGenerator:
|
||||||
options = var.options or []
|
options = var.options or []
|
||||||
if user_input_value not in options:
|
if user_input_value not in options:
|
||||||
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
|
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
|
||||||
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
|
||||||
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
|
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
|
||||||
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
|
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ class AppQueueManager:
|
||||||
self._user_id = user_id
|
self._user_id = user_id
|
||||||
self._invoke_from = invoke_from
|
self._invoke_from = invoke_from
|
||||||
|
|
||||||
user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
|
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
||||||
redis_client.setex(
|
redis_client.setex(
|
||||||
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
|
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
|
||||||
)
|
)
|
||||||
|
@ -118,7 +118,7 @@ class AppQueueManager:
|
||||||
if result is None:
|
if result is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
|
user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
||||||
if result.decode("utf-8") != f"{user_prefix}-{user_id}":
|
if result.decode("utf-8") != f"{user_prefix}-{user_id}":
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -148,7 +148,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||||
# get from source
|
# get from source
|
||||||
end_user_id = None
|
end_user_id = None
|
||||||
account_id = None
|
account_id = None
|
||||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||||
from_source = "api"
|
from_source = "api"
|
||||||
end_user_id = application_generate_entity.user_id
|
end_user_id = application_generate_entity.user_id
|
||||||
else:
|
else:
|
||||||
|
@ -165,11 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||||
model_provider = application_generate_entity.model_conf.provider
|
model_provider = application_generate_entity.model_conf.provider
|
||||||
model_id = application_generate_entity.model_conf.model
|
model_id = application_generate_entity.model_conf.model
|
||||||
override_model_configs = None
|
override_model_configs = None
|
||||||
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [
|
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in {
|
||||||
AppMode.AGENT_CHAT,
|
AppMode.AGENT_CHAT,
|
||||||
AppMode.CHAT,
|
AppMode.CHAT,
|
||||||
AppMode.COMPLETION,
|
AppMode.COMPLETION,
|
||||||
]:
|
}:
|
||||||
override_model_configs = app_config.app_model_config_dict
|
override_model_configs = app_config.app_model_config_dict
|
||||||
|
|
||||||
# get conversation introduction
|
# get conversation introduction
|
||||||
|
|
|
@ -53,7 +53,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||||
app_config = cast(WorkflowAppConfig, app_config)
|
app_config = cast(WorkflowAppConfig, app_config)
|
||||||
|
|
||||||
user_id = None
|
user_id = None
|
||||||
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||||
if end_user:
|
if end_user:
|
||||||
user_id = end_user.session_id
|
user_id = end_user.session_id
|
||||||
|
@ -113,7 +113,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||||
user_id=self.application_generate_entity.user_id,
|
user_id=self.application_generate_entity.user_id,
|
||||||
user_from=(
|
user_from=(
|
||||||
UserFrom.ACCOUNT
|
UserFrom.ACCOUNT
|
||||||
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||||
else UserFrom.END_USER
|
else UserFrom.END_USER
|
||||||
),
|
),
|
||||||
invoke_from=self.application_generate_entity.invoke_from,
|
invoke_from=self.application_generate_entity.invoke_from,
|
||||||
|
|
|
@ -63,7 +63,7 @@ class AnnotationReplyFeature:
|
||||||
score = documents[0].metadata["score"]
|
score = documents[0].metadata["score"]
|
||||||
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
|
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
|
||||||
if annotation:
|
if annotation:
|
||||||
if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
|
if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}:
|
||||||
from_source = "api"
|
from_source = "api"
|
||||||
else:
|
else:
|
||||||
from_source = "console"
|
from_source = "console"
|
||||||
|
|
|
@ -372,7 +372,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||||
self._message,
|
self._message,
|
||||||
application_generate_entity=self._application_generate_entity,
|
application_generate_entity=self._application_generate_entity,
|
||||||
conversation=self._conversation,
|
conversation=self._conversation,
|
||||||
is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT]
|
is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
|
||||||
and self._application_generate_entity.conversation_id is None,
|
and self._application_generate_entity.conversation_id is None,
|
||||||
extras=self._application_generate_entity.extras,
|
extras=self._application_generate_entity.extras,
|
||||||
)
|
)
|
||||||
|
|
|
@ -383,7 +383,7 @@ class WorkflowCycleManage:
|
||||||
:param workflow_node_execution: workflow node execution
|
:param workflow_node_execution: workflow node execution
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
|
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
response = NodeStartStreamResponse(
|
response = NodeStartStreamResponse(
|
||||||
|
@ -430,7 +430,7 @@ class WorkflowCycleManage:
|
||||||
:param workflow_node_execution: workflow node execution
|
:param workflow_node_execution: workflow node execution
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
|
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return NodeFinishStreamResponse(
|
return NodeFinishStreamResponse(
|
||||||
|
|
|
@ -29,7 +29,7 @@ class DatasetIndexToolCallbackHandler:
|
||||||
source="app",
|
source="app",
|
||||||
source_app_id=self._app_id,
|
source_app_id=self._app_id,
|
||||||
created_by_role=(
|
created_by_role=(
|
||||||
"account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
|
"account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
|
||||||
),
|
),
|
||||||
created_by=self._user_id,
|
created_by=self._user_id,
|
||||||
)
|
)
|
||||||
|
|
|
@ -292,7 +292,7 @@ class IndexingRunner:
|
||||||
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
|
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
# load file
|
# load file
|
||||||
if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]:
|
if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
data_source_info = dataset_document.data_source_info_dict
|
data_source_info = dataset_document.data_source_info_dict
|
||||||
|
|
|
@ -52,7 +52,7 @@ class TokenBufferMemory:
|
||||||
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||||
if files:
|
if files:
|
||||||
file_extra_config = None
|
file_extra_config = None
|
||||||
if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
|
if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||||
else:
|
else:
|
||||||
if message.workflow_run_id:
|
if message.workflow_run_id:
|
||||||
|
|
|
@ -27,17 +27,17 @@ class ModelType(Enum):
|
||||||
|
|
||||||
:return: model type
|
:return: model type
|
||||||
"""
|
"""
|
||||||
if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value:
|
if origin_model_type in {"text-generation", cls.LLM.value}:
|
||||||
return cls.LLM
|
return cls.LLM
|
||||||
elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value:
|
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
|
||||||
return cls.TEXT_EMBEDDING
|
return cls.TEXT_EMBEDDING
|
||||||
elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value:
|
elif origin_model_type in {"reranking", cls.RERANK.value}:
|
||||||
return cls.RERANK
|
return cls.RERANK
|
||||||
elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value:
|
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
|
||||||
return cls.SPEECH2TEXT
|
return cls.SPEECH2TEXT
|
||||||
elif origin_model_type == "tts" or origin_model_type == cls.TTS.value:
|
elif origin_model_type in {"tts", cls.TTS.value}:
|
||||||
return cls.TTS
|
return cls.TTS
|
||||||
elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value:
|
elif origin_model_type in {"text2img", cls.TEXT2IMG.value}:
|
||||||
return cls.TEXT2IMG
|
return cls.TEXT2IMG
|
||||||
elif origin_model_type == cls.MODERATION.value:
|
elif origin_model_type == cls.MODERATION.value:
|
||||||
return cls.MODERATION
|
return cls.MODERATION
|
||||||
|
|
|
@ -494,7 +494,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||||
mime_type = data_split[0].replace("data:", "")
|
mime_type = data_split[0].replace("data:", "")
|
||||||
base64_data = data_split[1]
|
base64_data = data_split[1]
|
||||||
|
|
||||||
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
|
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported image type {mime_type}, "
|
f"Unsupported image type {mime_type}, "
|
||||||
f"only support image/jpeg, image/png, image/gif, and image/webp"
|
f"only support image/jpeg, image/png, image/gif, and image/webp"
|
||||||
|
|
|
@ -85,14 +85,14 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||||
for i in range(len(sentences))
|
for i in range(len(sentences))
|
||||||
]
|
]
|
||||||
for future in futures:
|
for future in futures:
|
||||||
yield from future.result().__enter__().iter_bytes(1024)
|
yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801
|
||||||
|
|
||||||
else:
|
else:
|
||||||
response = client.audio.speech.with_streaming_response.create(
|
response = client.audio.speech.with_streaming_response.create(
|
||||||
model=model, voice=voice, response_format="mp3", input=content_text.strip()
|
model=model, voice=voice, response_format="mp3", input=content_text.strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
yield from response.__enter__().iter_bytes(1024)
|
yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise InvokeBadRequestError(str(ex))
|
raise InvokeBadRequestError(str(ex))
|
||||||
|
|
||||||
|
|
|
@ -454,7 +454,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||||
base64_data = data_split[1]
|
base64_data = data_split[1]
|
||||||
image_content = base64.b64decode(base64_data)
|
image_content = base64.b64decode(base64_data)
|
||||||
|
|
||||||
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
|
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported image type {mime_type}, "
|
f"Unsupported image type {mime_type}, "
|
||||||
f"only support image/jpeg, image/png, image/gif, and image/webp"
|
f"only support image/jpeg, image/png, image/gif, and image/webp"
|
||||||
|
@ -886,16 +886,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||||
|
|
||||||
if error_code == "AccessDeniedException":
|
if error_code == "AccessDeniedException":
|
||||||
return InvokeAuthorizationError(error_msg)
|
return InvokeAuthorizationError(error_msg)
|
||||||
elif error_code in ["ResourceNotFoundException", "ValidationException"]:
|
elif error_code in {"ResourceNotFoundException", "ValidationException"}:
|
||||||
return InvokeBadRequestError(error_msg)
|
return InvokeBadRequestError(error_msg)
|
||||||
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
|
elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
|
||||||
return InvokeRateLimitError(error_msg)
|
return InvokeRateLimitError(error_msg)
|
||||||
elif error_code in [
|
elif error_code in {
|
||||||
"ModelTimeoutException",
|
"ModelTimeoutException",
|
||||||
"ModelErrorException",
|
"ModelErrorException",
|
||||||
"InternalServerException",
|
"InternalServerException",
|
||||||
"ModelNotReadyException",
|
"ModelNotReadyException",
|
||||||
]:
|
}:
|
||||||
return InvokeServerUnavailableError(error_msg)
|
return InvokeServerUnavailableError(error_msg)
|
||||||
elif error_code == "ModelStreamErrorException":
|
elif error_code == "ModelStreamErrorException":
|
||||||
return InvokeConnectionError(error_msg)
|
return InvokeConnectionError(error_msg)
|
||||||
|
|
|
@ -186,16 +186,16 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||||
|
|
||||||
if error_code == "AccessDeniedException":
|
if error_code == "AccessDeniedException":
|
||||||
return InvokeAuthorizationError(error_msg)
|
return InvokeAuthorizationError(error_msg)
|
||||||
elif error_code in ["ResourceNotFoundException", "ValidationException"]:
|
elif error_code in {"ResourceNotFoundException", "ValidationException"}:
|
||||||
return InvokeBadRequestError(error_msg)
|
return InvokeBadRequestError(error_msg)
|
||||||
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
|
elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
|
||||||
return InvokeRateLimitError(error_msg)
|
return InvokeRateLimitError(error_msg)
|
||||||
elif error_code in [
|
elif error_code in {
|
||||||
"ModelTimeoutException",
|
"ModelTimeoutException",
|
||||||
"ModelErrorException",
|
"ModelErrorException",
|
||||||
"InternalServerException",
|
"InternalServerException",
|
||||||
"ModelNotReadyException",
|
"ModelNotReadyException",
|
||||||
]:
|
}:
|
||||||
return InvokeServerUnavailableError(error_msg)
|
return InvokeServerUnavailableError(error_msg)
|
||||||
elif error_code == "ModelStreamErrorException":
|
elif error_code == "ModelStreamErrorException":
|
||||||
return InvokeConnectionError(error_msg)
|
return InvokeConnectionError(error_msg)
|
||||||
|
|
|
@ -6,10 +6,10 @@ from collections.abc import Generator
|
||||||
from typing import Optional, Union, cast
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
import google.ai.generativelanguage as glm
|
import google.ai.generativelanguage as glm
|
||||||
import google.api_core.exceptions as exceptions
|
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
import google.generativeai.client as client
|
|
||||||
import requests
|
import requests
|
||||||
|
from google.api_core import exceptions
|
||||||
|
from google.generativeai import client
|
||||||
from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
|
from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
|
||||||
from google.generativeai.types.content_types import to_part
|
from google.generativeai.types.content_types import to_part
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
|
@ -77,7 +77,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||||
if "huggingfacehub_api_type" not in credentials:
|
if "huggingfacehub_api_type" not in credentials:
|
||||||
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")
|
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")
|
||||||
|
|
||||||
if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"):
|
if credentials["huggingfacehub_api_type"] not in {"inference_endpoints", "hosted_inference_api"}:
|
||||||
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")
|
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")
|
||||||
|
|
||||||
if "huggingfacehub_api_token" not in credentials:
|
if "huggingfacehub_api_token" not in credentials:
|
||||||
|
@ -94,7 +94,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||||
credentials["huggingfacehub_api_token"], model
|
credentials["huggingfacehub_api_token"], model
|
||||||
)
|
)
|
||||||
|
|
||||||
if credentials["task_type"] not in ("text2text-generation", "text-generation"):
|
if credentials["task_type"] not in {"text2text-generation", "text-generation"}:
|
||||||
raise CredentialsValidateFailedError(
|
raise CredentialsValidateFailedError(
|
||||||
"Huggingface Hub Task Type must be one of text2text-generation, text-generation."
|
"Huggingface Hub Task Type must be one of text2text-generation, text-generation."
|
||||||
)
|
)
|
||||||
|
|
|
@ -75,7 +75,7 @@ class TeiHelper:
|
||||||
if len(model_type.keys()) < 1:
|
if len(model_type.keys()) < 1:
|
||||||
raise RuntimeError("model_type is empty")
|
raise RuntimeError("model_type is empty")
|
||||||
model_type = list(model_type.keys())[0]
|
model_type = list(model_type.keys())[0]
|
||||||
if model_type not in ["embedding", "reranker"]:
|
if model_type not in {"embedding", "reranker"}:
|
||||||
raise RuntimeError(f"invalid model_type: {model_type}")
|
raise RuntimeError(f"invalid model_type: {model_type}")
|
||||||
|
|
||||||
max_input_length = response_json.get("max_input_length", 512)
|
max_input_length = response_json.get("max_input_length", 512)
|
||||||
|
|
|
@ -100,9 +100,9 @@ class MinimaxChatCompletion:
|
||||||
return self._handle_chat_generate_response(response)
|
return self._handle_chat_generate_response(response)
|
||||||
|
|
||||||
def _handle_error(self, code: int, msg: str):
|
def _handle_error(self, code: int, msg: str):
|
||||||
if code == 1000 or code == 1001 or code == 1013 or code == 1027:
|
if code in {1000, 1001, 1013, 1027}:
|
||||||
raise InternalServerError(msg)
|
raise InternalServerError(msg)
|
||||||
elif code == 1002 or code == 1039:
|
elif code in {1002, 1039}:
|
||||||
raise RateLimitReachedError(msg)
|
raise RateLimitReachedError(msg)
|
||||||
elif code == 1004:
|
elif code == 1004:
|
||||||
raise InvalidAuthenticationError(msg)
|
raise InvalidAuthenticationError(msg)
|
||||||
|
|
|
@ -105,9 +105,9 @@ class MinimaxChatCompletionPro:
|
||||||
return self._handle_chat_generate_response(response)
|
return self._handle_chat_generate_response(response)
|
||||||
|
|
||||||
def _handle_error(self, code: int, msg: str):
|
def _handle_error(self, code: int, msg: str):
|
||||||
if code == 1000 or code == 1001 or code == 1013 or code == 1027:
|
if code in {1000, 1001, 1013, 1027}:
|
||||||
raise InternalServerError(msg)
|
raise InternalServerError(msg)
|
||||||
elif code == 1002 or code == 1039:
|
elif code in {1002, 1039}:
|
||||||
raise RateLimitReachedError(msg)
|
raise RateLimitReachedError(msg)
|
||||||
elif code == 1004:
|
elif code == 1004:
|
||||||
raise InvalidAuthenticationError(msg)
|
raise InvalidAuthenticationError(msg)
|
||||||
|
|
|
@ -114,7 +114,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
||||||
raise CredentialsValidateFailedError("Invalid api key")
|
raise CredentialsValidateFailedError("Invalid api key")
|
||||||
|
|
||||||
def _handle_error(self, code: int, msg: str):
|
def _handle_error(self, code: int, msg: str):
|
||||||
if code == 1000 or code == 1001:
|
if code in {1000, 1001}:
|
||||||
raise InternalServerError(msg)
|
raise InternalServerError(msg)
|
||||||
elif code == 1002:
|
elif code == 1002:
|
||||||
raise RateLimitReachedError(msg)
|
raise RateLimitReachedError(msg)
|
||||||
|
|
|
@ -125,7 +125,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||||
model_mode = self.get_model_mode(base_model, credentials)
|
model_mode = self.get_model_mode(base_model, credentials)
|
||||||
|
|
||||||
# transform response format
|
# transform response format
|
||||||
if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
|
if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
|
||||||
stop = stop or []
|
stop = stop or []
|
||||||
if model_mode == LLMMode.CHAT:
|
if model_mode == LLMMode.CHAT:
|
||||||
# chat model
|
# chat model
|
||||||
|
|
|
@ -89,14 +89,14 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||||
for i in range(len(sentences))
|
for i in range(len(sentences))
|
||||||
]
|
]
|
||||||
for future in futures:
|
for future in futures:
|
||||||
yield from future.result().__enter__().iter_bytes(1024)
|
yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801
|
||||||
|
|
||||||
else:
|
else:
|
||||||
response = client.audio.speech.with_streaming_response.create(
|
response = client.audio.speech.with_streaming_response.create(
|
||||||
model=model, voice=voice, response_format="mp3", input=content_text.strip()
|
model=model, voice=voice, response_format="mp3", input=content_text.strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
yield from response.__enter__().iter_bytes(1024)
|
yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise InvokeBadRequestError(str(ex))
|
raise InvokeBadRequestError(str(ex))
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,6 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||||
credentials["endpoint_url"] = "https://openrouter.ai/api/v1"
|
credentials["endpoint_url"] = "https://openrouter.ai/api/v1"
|
||||||
credentials["mode"] = self.get_model_mode(model).value
|
credentials["mode"] = self.get_model_mode(model).value
|
||||||
credentials["function_calling_type"] = "tool_call"
|
credentials["function_calling_type"] = "tool_call"
|
||||||
return
|
|
||||||
|
|
||||||
def _invoke(
|
def _invoke(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -154,7 +154,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
for key, value in input_properties:
|
for key, value in input_properties:
|
||||||
if key not in ["system_prompt", "prompt"] and "stop" not in key:
|
if key not in {"system_prompt", "prompt"} and "stop" not in key:
|
||||||
value_type = value.get("type")
|
value_type = value.get("type")
|
||||||
|
|
||||||
if not value_type:
|
if not value_type:
|
||||||
|
|
|
@ -86,7 +86,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
for input_property in input_properties:
|
for input_property in input_properties:
|
||||||
if input_property[0] in ("text", "texts", "inputs"):
|
if input_property[0] in {"text", "texts", "inputs"}:
|
||||||
text_input_key = input_property[0]
|
text_input_key = input_property[0]
|
||||||
return text_input_key
|
return text_input_key
|
||||||
|
|
||||||
|
@ -96,7 +96,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
|
||||||
def _generate_embeddings_by_text_input_key(
|
def _generate_embeddings_by_text_input_key(
|
||||||
client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str]
|
client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str]
|
||||||
) -> list[list[float]]:
|
) -> list[list[float]]:
|
||||||
if text_input_key in ("text", "inputs"):
|
if text_input_key in {"text", "inputs"}:
|
||||||
embeddings = []
|
embeddings = []
|
||||||
for text in texts:
|
for text in texts:
|
||||||
result = client.run(replicate_model_version, input={text_input_key: text})
|
result = client.run(replicate_model_version, input={text_input_key: text})
|
||||||
|
|
|
@ -89,7 +89,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
||||||
:param tools: tools for tool calling
|
:param tools: tools for tool calling
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if model in ["qwen-turbo-chat", "qwen-plus-chat"]:
|
if model in {"qwen-turbo-chat", "qwen-plus-chat"}:
|
||||||
model = model.replace("-chat", "")
|
model = model.replace("-chat", "")
|
||||||
if model == "farui-plus":
|
if model == "farui-plus":
|
||||||
model = "qwen-farui-plus"
|
model = "qwen-farui-plus"
|
||||||
|
@ -157,7 +157,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
||||||
|
|
||||||
mode = self.get_model_mode(model, credentials)
|
mode = self.get_model_mode(model, credentials)
|
||||||
|
|
||||||
if model in ["qwen-turbo-chat", "qwen-plus-chat"]:
|
if model in {"qwen-turbo-chat", "qwen-plus-chat"}:
|
||||||
model = model.replace("-chat", "")
|
model = model.replace("-chat", "")
|
||||||
|
|
||||||
extra_model_kwargs = {}
|
extra_model_kwargs = {}
|
||||||
|
@ -201,7 +201,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
||||||
:param prompt_messages: prompt messages
|
:param prompt_messages: prompt messages
|
||||||
:return: llm response
|
:return: llm response
|
||||||
"""
|
"""
|
||||||
if response.status_code != 200 and response.status_code != HTTPStatus.OK:
|
if response.status_code not in {200, HTTPStatus.OK}:
|
||||||
raise ServiceUnavailableError(response.message)
|
raise ServiceUnavailableError(response.message)
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
|
@ -240,7 +240,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
||||||
full_text = ""
|
full_text = ""
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
for index, response in enumerate(responses):
|
for index, response in enumerate(responses):
|
||||||
if response.status_code != 200 and response.status_code != HTTPStatus.OK:
|
if response.status_code not in {200, HTTPStatus.OK}:
|
||||||
raise ServiceUnavailableError(
|
raise ServiceUnavailableError(
|
||||||
f"Failed to invoke model {model}, status code: {response.status_code}, "
|
f"Failed to invoke model {model}, status code: {response.status_code}, "
|
||||||
f"message: {response.message}"
|
f"message: {response.message}"
|
||||||
|
|
|
@ -93,7 +93,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
|
||||||
"""
|
"""
|
||||||
Code block mode wrapper for invoking large language model
|
Code block mode wrapper for invoking large language model
|
||||||
"""
|
"""
|
||||||
if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
|
if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
|
||||||
stop = stop or []
|
stop = stop or []
|
||||||
self._transform_chat_json_prompts(
|
self._transform_chat_json_prompts(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
|
@ -5,7 +5,6 @@ import logging
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Optional, Union, cast
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
import google.api_core.exceptions as exceptions
|
|
||||||
import google.auth.transport.requests
|
import google.auth.transport.requests
|
||||||
import vertexai.generative_models as glm
|
import vertexai.generative_models as glm
|
||||||
from anthropic import AnthropicVertex, Stream
|
from anthropic import AnthropicVertex, Stream
|
||||||
|
@ -17,6 +16,7 @@ from anthropic.types import (
|
||||||
MessageStopEvent,
|
MessageStopEvent,
|
||||||
MessageStreamEvent,
|
MessageStreamEvent,
|
||||||
)
|
)
|
||||||
|
from google.api_core import exceptions
|
||||||
from google.cloud import aiplatform
|
from google.cloud import aiplatform
|
||||||
from google.oauth2 import service_account
|
from google.oauth2 import service_account
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -346,7 +346,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
||||||
mime_type = data_split[0].replace("data:", "")
|
mime_type = data_split[0].replace("data:", "")
|
||||||
base64_data = data_split[1]
|
base64_data = data_split[1]
|
||||||
|
|
||||||
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
|
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported image type {mime_type}, "
|
f"Unsupported image type {mime_type}, "
|
||||||
f"only support image/jpeg, image/png, image/gif, and image/webp"
|
f"only support image/jpeg, image/png, image/gif, and image/webp"
|
||||||
|
|
|
@ -96,7 +96,6 @@ class Signer:
|
||||||
signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service)
|
signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service)
|
||||||
sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str))
|
sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str))
|
||||||
request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials)
|
request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials)
|
||||||
return
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def hashed_canonical_request_v4(request, meta):
|
def hashed_canonical_request_v4(request, meta):
|
||||||
|
@ -105,7 +104,7 @@ class Signer:
|
||||||
|
|
||||||
signed_headers = {}
|
signed_headers = {}
|
||||||
for key in request.headers:
|
for key in request.headers:
|
||||||
if key in ["Content-Type", "Content-Md5", "Host"] or key.startswith("X-"):
|
if key in {"Content-Type", "Content-Md5", "Host"} or key.startswith("X-"):
|
||||||
signed_headers[key.lower()] = request.headers[key]
|
signed_headers[key.lower()] = request.headers[key]
|
||||||
|
|
||||||
if "host" in signed_headers:
|
if "host" in signed_headers:
|
||||||
|
|
|
@ -69,7 +69,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
||||||
"""
|
"""
|
||||||
Code block mode wrapper for invoking large language model
|
Code block mode wrapper for invoking large language model
|
||||||
"""
|
"""
|
||||||
if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
|
if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
|
||||||
response_format = model_parameters["response_format"]
|
response_format = model_parameters["response_format"]
|
||||||
stop = stop or []
|
stop = stop or []
|
||||||
self._transform_json_prompts(
|
self._transform_json_prompts(
|
||||||
|
|
|
@ -103,7 +103,7 @@ class XinferenceHelper:
|
||||||
model_handle_type = "embedding"
|
model_handle_type = "embedding"
|
||||||
elif response_json.get("model_type") == "audio":
|
elif response_json.get("model_type") == "audio":
|
||||||
model_handle_type = "audio"
|
model_handle_type = "audio"
|
||||||
if model_family and model_family in ["ChatTTS", "CosyVoice", "FishAudio"]:
|
if model_family and model_family in {"ChatTTS", "CosyVoice", "FishAudio"}:
|
||||||
model_ability.append("text-to-audio")
|
model_ability.append("text-to-audio")
|
||||||
else:
|
else:
|
||||||
model_ability.append("audio-to-text")
|
model_ability.append("audio-to-text")
|
||||||
|
|
|
@ -186,10 +186,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||||
new_prompt_messages: list[PromptMessage] = []
|
new_prompt_messages: list[PromptMessage] = []
|
||||||
for prompt_message in prompt_messages:
|
for prompt_message in prompt_messages:
|
||||||
copy_prompt_message = prompt_message.copy()
|
copy_prompt_message = prompt_message.copy()
|
||||||
if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]:
|
if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}:
|
||||||
if isinstance(copy_prompt_message.content, list):
|
if isinstance(copy_prompt_message.content, list):
|
||||||
# check if model is 'glm-4v'
|
# check if model is 'glm-4v'
|
||||||
if model not in ("glm-4v", "glm-4v-plus"):
|
if model not in {"glm-4v", "glm-4v-plus"}:
|
||||||
# not support list message
|
# not support list message
|
||||||
continue
|
continue
|
||||||
# get image and
|
# get image and
|
||||||
|
@ -209,10 +209,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||||
):
|
):
|
||||||
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
|
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
|
||||||
else:
|
else:
|
||||||
if (
|
if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.TOOL}:
|
||||||
copy_prompt_message.role == PromptMessageRole.USER
|
|
||||||
or copy_prompt_message.role == PromptMessageRole.TOOL
|
|
||||||
):
|
|
||||||
new_prompt_messages.append(copy_prompt_message)
|
new_prompt_messages.append(copy_prompt_message)
|
||||||
elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
|
elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
|
||||||
new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
|
new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
|
||||||
|
@ -226,7 +223,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||||
else:
|
else:
|
||||||
new_prompt_messages.append(copy_prompt_message)
|
new_prompt_messages.append(copy_prompt_message)
|
||||||
|
|
||||||
if model == "glm-4v" or model == "glm-4v-plus":
|
if model in {"glm-4v", "glm-4v-plus"}:
|
||||||
params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
|
params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
|
||||||
else:
|
else:
|
||||||
params = {"model": model, "messages": [], **model_parameters}
|
params = {"model": model, "messages": [], **model_parameters}
|
||||||
|
@ -270,11 +267,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||||
# chatglm model
|
# chatglm model
|
||||||
for prompt_message in new_prompt_messages:
|
for prompt_message in new_prompt_messages:
|
||||||
# merge system message to user message
|
# merge system message to user message
|
||||||
if (
|
if prompt_message.role in {
|
||||||
prompt_message.role == PromptMessageRole.SYSTEM
|
PromptMessageRole.SYSTEM,
|
||||||
or prompt_message.role == PromptMessageRole.TOOL
|
PromptMessageRole.TOOL,
|
||||||
or prompt_message.role == PromptMessageRole.USER
|
PromptMessageRole.USER,
|
||||||
):
|
}:
|
||||||
if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user":
|
if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user":
|
||||||
params["messages"][-1]["content"] += "\n\n" + prompt_message.content
|
params["messages"][-1]["content"] += "\n\n" + prompt_message.content
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .fine_tuning_job import FineTuningJob as FineTuningJob
|
from .fine_tuning_job import FineTuningJob, ListOfFineTuningJob
|
||||||
from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob
|
from .fine_tuning_job_event import FineTuningJobEvent
|
||||||
from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent
|
|
||||||
|
|
|
@ -75,7 +75,7 @@ class CommonValidator:
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
raise ValueError(f"Variable {credential_form_schema.variable} should be string")
|
raise ValueError(f"Variable {credential_form_schema.variable} should be string")
|
||||||
|
|
||||||
if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]:
|
if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}:
|
||||||
# If the value is in options, no validation is performed
|
# If the value is in options, no validation is performed
|
||||||
if credential_form_schema.options:
|
if credential_form_schema.options:
|
||||||
if value not in [option.value for option in credential_form_schema.options]:
|
if value not in [option.value for option in credential_form_schema.options]:
|
||||||
|
@ -83,7 +83,7 @@ class CommonValidator:
|
||||||
|
|
||||||
if credential_form_schema.type == FormType.SWITCH:
|
if credential_form_schema.type == FormType.SWITCH:
|
||||||
# If the value is not in ['true', 'false'], an exception is thrown
|
# If the value is not in ['true', 'false'], an exception is thrown
|
||||||
if value.lower() not in ["true", "false"]:
|
if value.lower() not in {"true", "false"}:
|
||||||
raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")
|
raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")
|
||||||
|
|
||||||
value = True if value.lower() == "true" else False
|
value = True if value.lower() == "true" else False
|
||||||
|
|
|
@ -51,7 +51,7 @@ class ElasticSearchVector(BaseVector):
|
||||||
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
|
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
|
||||||
try:
|
try:
|
||||||
parsed_url = urlparse(config.host)
|
parsed_url = urlparse(config.host)
|
||||||
if parsed_url.scheme in ["http", "https"]:
|
if parsed_url.scheme in {"http", "https"}:
|
||||||
hosts = f"{config.host}:{config.port}"
|
hosts = f"{config.host}:{config.port}"
|
||||||
else:
|
else:
|
||||||
hosts = f"http://{config.host}:{config.port}"
|
hosts = f"http://{config.host}:{config.port}"
|
||||||
|
@ -94,7 +94,7 @@ class ElasticSearchVector(BaseVector):
|
||||||
return uuids
|
return uuids
|
||||||
|
|
||||||
def text_exists(self, id: str) -> bool:
|
def text_exists(self, id: str) -> bool:
|
||||||
return self._client.exists(index=self._collection_name, id=id).__bool__()
|
return bool(self._client.exists(index=self._collection_name, id=id))
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
for id in ids:
|
for id in ids:
|
||||||
|
|
|
@ -35,7 +35,7 @@ class MyScaleVector(BaseVector):
|
||||||
super().__init__(collection_name)
|
super().__init__(collection_name)
|
||||||
self._config = config
|
self._config = config
|
||||||
self._metric = metric
|
self._metric = metric
|
||||||
self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC
|
self._vec_order = SortOrder.ASC if metric.upper() in {"COSINE", "L2"} else SortOrder.DESC
|
||||||
self._client = get_client(
|
self._client = get_client(
|
||||||
host=config.host,
|
host=config.host,
|
||||||
port=config.port,
|
port=config.port,
|
||||||
|
@ -92,7 +92,7 @@ class MyScaleVector(BaseVector):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def escape_str(value: Any) -> str:
|
def escape_str(value: Any) -> str:
|
||||||
return "".join(" " if c in ("\\", "'") else c for c in str(value))
|
return "".join(" " if c in {"\\", "'"} else c for c in str(value))
|
||||||
|
|
||||||
def text_exists(self, id: str) -> bool:
|
def text_exists(self, id: str) -> bool:
|
||||||
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
|
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
|
||||||
|
|
|
@ -223,15 +223,7 @@ class OracleVector(BaseVector):
|
||||||
words = pseg.cut(query)
|
words = pseg.cut(query)
|
||||||
current_entity = ""
|
current_entity = ""
|
||||||
for word, pos in words:
|
for word, pos in words:
|
||||||
if (
|
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
|
||||||
pos == "nr"
|
|
||||||
or pos == "Ng"
|
|
||||||
or pos == "eng"
|
|
||||||
or pos == "nz"
|
|
||||||
or pos == "n"
|
|
||||||
or pos == "ORG"
|
|
||||||
or pos == "v"
|
|
||||||
): # nr: 人名, ns: 地名, nt: 机构名
|
|
||||||
current_entity += word
|
current_entity += word
|
||||||
else:
|
else:
|
||||||
if current_entity:
|
if current_entity:
|
||||||
|
|
|
@ -98,17 +98,17 @@ class ExtractProcessor:
|
||||||
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
|
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
|
||||||
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
|
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
|
||||||
if etl_type == "Unstructured":
|
if etl_type == "Unstructured":
|
||||||
if file_extension == ".xlsx" or file_extension == ".xls":
|
if file_extension in {".xlsx", ".xls"}:
|
||||||
extractor = ExcelExtractor(file_path)
|
extractor = ExcelExtractor(file_path)
|
||||||
elif file_extension == ".pdf":
|
elif file_extension == ".pdf":
|
||||||
extractor = PdfExtractor(file_path)
|
extractor = PdfExtractor(file_path)
|
||||||
elif file_extension in [".md", ".markdown"]:
|
elif file_extension in {".md", ".markdown"}:
|
||||||
extractor = (
|
extractor = (
|
||||||
UnstructuredMarkdownExtractor(file_path, unstructured_api_url)
|
UnstructuredMarkdownExtractor(file_path, unstructured_api_url)
|
||||||
if is_automatic
|
if is_automatic
|
||||||
else MarkdownExtractor(file_path, autodetect_encoding=True)
|
else MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||||
)
|
)
|
||||||
elif file_extension in [".htm", ".html"]:
|
elif file_extension in {".htm", ".html"}:
|
||||||
extractor = HtmlExtractor(file_path)
|
extractor = HtmlExtractor(file_path)
|
||||||
elif file_extension == ".docx":
|
elif file_extension == ".docx":
|
||||||
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||||
|
@ -134,13 +134,13 @@ class ExtractProcessor:
|
||||||
else TextExtractor(file_path, autodetect_encoding=True)
|
else TextExtractor(file_path, autodetect_encoding=True)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if file_extension == ".xlsx" or file_extension == ".xls":
|
if file_extension in {".xlsx", ".xls"}:
|
||||||
extractor = ExcelExtractor(file_path)
|
extractor = ExcelExtractor(file_path)
|
||||||
elif file_extension == ".pdf":
|
elif file_extension == ".pdf":
|
||||||
extractor = PdfExtractor(file_path)
|
extractor = PdfExtractor(file_path)
|
||||||
elif file_extension in [".md", ".markdown"]:
|
elif file_extension in {".md", ".markdown"}:
|
||||||
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
|
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||||
elif file_extension in [".htm", ".html"]:
|
elif file_extension in {".htm", ".html"}:
|
||||||
extractor = HtmlExtractor(file_path)
|
extractor = HtmlExtractor(file_path)
|
||||||
elif file_extension == ".docx":
|
elif file_extension == ".docx":
|
||||||
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||||
|
|
|
@ -32,7 +32,7 @@ class FirecrawlApp:
|
||||||
else:
|
else:
|
||||||
raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
|
raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
|
||||||
|
|
||||||
elif response.status_code in [402, 409, 500]:
|
elif response.status_code in {402, 409, 500}:
|
||||||
error_message = response.json().get("error", "Unknown error occurred")
|
error_message = response.json().get("error", "Unknown error occurred")
|
||||||
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}")
|
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}")
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -103,12 +103,12 @@ class NotionExtractor(BaseExtractor):
|
||||||
multi_select_list = property_value[type]
|
multi_select_list = property_value[type]
|
||||||
for multi_select in multi_select_list:
|
for multi_select in multi_select_list:
|
||||||
value.append(multi_select["name"])
|
value.append(multi_select["name"])
|
||||||
elif type == "rich_text" or type == "title":
|
elif type in {"rich_text", "title"}:
|
||||||
if len(property_value[type]) > 0:
|
if len(property_value[type]) > 0:
|
||||||
value = property_value[type][0]["plain_text"]
|
value = property_value[type][0]["plain_text"]
|
||||||
else:
|
else:
|
||||||
value = ""
|
value = ""
|
||||||
elif type == "select" or type == "status":
|
elif type in {"select", "status"}:
|
||||||
if property_value[type]:
|
if property_value[type]:
|
||||||
value = property_value[type]["name"]
|
value = property_value[type]["name"]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -115,7 +115,7 @@ class DatasetRetrieval:
|
||||||
|
|
||||||
available_datasets.append(dataset)
|
available_datasets.append(dataset)
|
||||||
all_documents = []
|
all_documents = []
|
||||||
user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
|
user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
|
||||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||||
all_documents = self.single_retrieve(
|
all_documents = self.single_retrieve(
|
||||||
app_id,
|
app_id,
|
||||||
|
|
|
@ -35,7 +35,7 @@ def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> l
|
||||||
splits = re.split(separator, text)
|
splits = re.split(separator, text)
|
||||||
else:
|
else:
|
||||||
splits = list(text)
|
splits = list(text)
|
||||||
return [s for s in splits if (s != "" and s != "\n")]
|
return [s for s in splits if (s not in {"", "\n"})]
|
||||||
|
|
||||||
|
|
||||||
class TextSplitter(BaseDocumentTransformer, ABC):
|
class TextSplitter(BaseDocumentTransformer, ABC):
|
||||||
|
|
|
@ -68,7 +68,7 @@ class AppToolProviderEntity(ToolProviderController):
|
||||||
label = input_form[form_type]["label"]
|
label = input_form[form_type]["label"]
|
||||||
variable_name = input_form[form_type]["variable_name"]
|
variable_name = input_form[form_type]["variable_name"]
|
||||||
options = input_form[form_type].get("options", [])
|
options = input_form[form_type].get("options", [])
|
||||||
if form_type == "paragraph" or form_type == "text-input":
|
if form_type in {"paragraph", "text-input"}:
|
||||||
tool["parameters"].append(
|
tool["parameters"].append(
|
||||||
ToolParameter(
|
ToolParameter(
|
||||||
name=variable_name,
|
name=variable_name,
|
||||||
|
|
|
@ -168,7 +168,7 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||||
pass
|
pass
|
||||||
elif event == "close":
|
elif event == "close":
|
||||||
break
|
break
|
||||||
elif event == "error" or event == "filter":
|
elif event in {"error", "filter"}:
|
||||||
raise Exception(f"Failed to generate outline: {data}")
|
raise Exception(f"Failed to generate outline: {data}")
|
||||||
|
|
||||||
return outline
|
return outline
|
||||||
|
@ -213,7 +213,7 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||||
pass
|
pass
|
||||||
elif event == "close":
|
elif event == "close":
|
||||||
break
|
break
|
||||||
elif event == "error" or event == "filter":
|
elif event in {"error", "filter"}:
|
||||||
raise Exception(f"Failed to generate content: {data}")
|
raise Exception(f"Failed to generate content: {data}")
|
||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
|
@ -39,11 +39,11 @@ class DallE3Tool(BuiltinTool):
|
||||||
n = tool_parameters.get("n", 1)
|
n = tool_parameters.get("n", 1)
|
||||||
# get quality
|
# get quality
|
||||||
quality = tool_parameters.get("quality", "standard")
|
quality = tool_parameters.get("quality", "standard")
|
||||||
if quality not in ["standard", "hd"]:
|
if quality not in {"standard", "hd"}:
|
||||||
return self.create_text_message("Invalid quality")
|
return self.create_text_message("Invalid quality")
|
||||||
# get style
|
# get style
|
||||||
style = tool_parameters.get("style", "vivid")
|
style = tool_parameters.get("style", "vivid")
|
||||||
if style not in ["natural", "vivid"]:
|
if style not in {"natural", "vivid"}:
|
||||||
return self.create_text_message("Invalid style")
|
return self.create_text_message("Invalid style")
|
||||||
# set extra body
|
# set extra body
|
||||||
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
|
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
|
||||||
|
|
|
@ -14,7 +14,7 @@ class SimpleCode(BuiltinTool):
|
||||||
language = tool_parameters.get("language", CodeLanguage.PYTHON3)
|
language = tool_parameters.get("language", CodeLanguage.PYTHON3)
|
||||||
code = tool_parameters.get("code", "")
|
code = tool_parameters.get("code", "")
|
||||||
|
|
||||||
if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]:
|
if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
|
||||||
raise ValueError(f"Only python3 and javascript are supported, not {language}")
|
raise ValueError(f"Only python3 and javascript are supported, not {language}")
|
||||||
|
|
||||||
result = CodeExecutor.execute_code(language, "", code)
|
result = CodeExecutor.execute_code(language, "", code)
|
||||||
|
|
|
@ -34,11 +34,11 @@ class CogView3Tool(BuiltinTool):
|
||||||
n = tool_parameters.get("n", 1)
|
n = tool_parameters.get("n", 1)
|
||||||
# get quality
|
# get quality
|
||||||
quality = tool_parameters.get("quality", "standard")
|
quality = tool_parameters.get("quality", "standard")
|
||||||
if quality not in ["standard", "hd"]:
|
if quality not in {"standard", "hd"}:
|
||||||
return self.create_text_message("Invalid quality")
|
return self.create_text_message("Invalid quality")
|
||||||
# get style
|
# get style
|
||||||
style = tool_parameters.get("style", "vivid")
|
style = tool_parameters.get("style", "vivid")
|
||||||
if style not in ["natural", "vivid"]:
|
if style not in {"natural", "vivid"}:
|
||||||
return self.create_text_message("Invalid style")
|
return self.create_text_message("Invalid style")
|
||||||
# set extra body
|
# set extra body
|
||||||
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
|
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
|
||||||
|
|
|
@ -49,11 +49,11 @@ class DallE3Tool(BuiltinTool):
|
||||||
n = tool_parameters.get("n", 1)
|
n = tool_parameters.get("n", 1)
|
||||||
# get quality
|
# get quality
|
||||||
quality = tool_parameters.get("quality", "standard")
|
quality = tool_parameters.get("quality", "standard")
|
||||||
if quality not in ["standard", "hd"]:
|
if quality not in {"standard", "hd"}:
|
||||||
return self.create_text_message("Invalid quality")
|
return self.create_text_message("Invalid quality")
|
||||||
# get style
|
# get style
|
||||||
style = tool_parameters.get("style", "vivid")
|
style = tool_parameters.get("style", "vivid")
|
||||||
if style not in ["natural", "vivid"]:
|
if style not in {"natural", "vivid"}:
|
||||||
return self.create_text_message("Invalid style")
|
return self.create_text_message("Invalid style")
|
||||||
|
|
||||||
# call openapi dalle3
|
# call openapi dalle3
|
||||||
|
|
|
@ -133,9 +133,9 @@ class GetWorksheetFieldsTool(BuiltinTool):
|
||||||
|
|
||||||
def _extract_options(self, control: dict) -> list:
|
def _extract_options(self, control: dict) -> list:
|
||||||
options = []
|
options = []
|
||||||
if control["type"] in [9, 10, 11]:
|
if control["type"] in {9, 10, 11}:
|
||||||
options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])])
|
options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])])
|
||||||
elif control["type"] in [28, 36]:
|
elif control["type"] in {28, 36}:
|
||||||
itemnames = control["advancedSetting"].get("itemnames")
|
itemnames = control["advancedSetting"].get("itemnames")
|
||||||
if itemnames and itemnames.startswith("[{"):
|
if itemnames and itemnames.startswith("[{"):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -183,11 +183,11 @@ class ListWorksheetRecordsTool(BuiltinTool):
|
||||||
type_id = field.get("typeId")
|
type_id = field.get("typeId")
|
||||||
if type_id == 10:
|
if type_id == 10:
|
||||||
value = value if isinstance(value, str) else "、".join(value)
|
value = value if isinstance(value, str) else "、".join(value)
|
||||||
elif type_id in [28, 36]:
|
elif type_id in {28, 36}:
|
||||||
value = field.get("options", {}).get(value, value)
|
value = field.get("options", {}).get(value, value)
|
||||||
elif type_id in [26, 27, 48, 14]:
|
elif type_id in {26, 27, 48, 14}:
|
||||||
value = self.process_value(value)
|
value = self.process_value(value)
|
||||||
elif type_id in [35, 29]:
|
elif type_id in {35, 29}:
|
||||||
value = self.parse_cascade_or_associated(field, value)
|
value = self.parse_cascade_or_associated(field, value)
|
||||||
elif type_id == 40:
|
elif type_id == 40:
|
||||||
value = self.parse_location(value)
|
value = self.parse_location(value)
|
||||||
|
|
|
@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool):
|
||||||
models_data=[],
|
models_data=[],
|
||||||
headers=headers,
|
headers=headers,
|
||||||
params=params,
|
params=params,
|
||||||
recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"),
|
recursive=result_type not in {"first sd_name", "first name sd_name pair"},
|
||||||
)
|
)
|
||||||
|
|
||||||
result_str = ""
|
result_str = ""
|
||||||
|
|
|
@ -38,7 +38,7 @@ class SearchAPI:
|
||||||
return {
|
return {
|
||||||
"engine": "google",
|
"engine": "google",
|
||||||
"q": query,
|
"q": query,
|
||||||
**{key: value for key, value in kwargs.items() if value not in [None, ""]},
|
**{key: value for key, value in kwargs.items() if value not in {None, ""}},
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -38,7 +38,7 @@ class SearchAPI:
|
||||||
return {
|
return {
|
||||||
"engine": "google_jobs",
|
"engine": "google_jobs",
|
||||||
"q": query,
|
"q": query,
|
||||||
**{key: value for key, value in kwargs.items() if value not in [None, ""]},
|
**{key: value for key, value in kwargs.items() if value not in {None, ""}},
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -38,7 +38,7 @@ class SearchAPI:
|
||||||
return {
|
return {
|
||||||
"engine": "google_news",
|
"engine": "google_news",
|
||||||
"q": query,
|
"q": query,
|
||||||
**{key: value for key, value in kwargs.items() if value not in [None, ""]},
|
**{key: value for key, value in kwargs.items() if value not in {None, ""}},
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -38,7 +38,7 @@ class SearchAPI:
|
||||||
"engine": "youtube_transcripts",
|
"engine": "youtube_transcripts",
|
||||||
"video_id": video_id,
|
"video_id": video_id,
|
||||||
"lang": language or "en",
|
"lang": language or "en",
|
||||||
**{key: value for key, value in kwargs.items() if value not in [None, ""]},
|
**{key: value for key, value in kwargs.items() if value not in {None, ""}},
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -214,7 +214,7 @@ class Spider:
|
||||||
return requests.delete(url, headers=headers, stream=stream)
|
return requests.delete(url, headers=headers, stream=stream)
|
||||||
|
|
||||||
def _handle_error(self, response, action):
|
def _handle_error(self, response, action):
|
||||||
if response.status_code in [402, 409, 500]:
|
if response.status_code in {402, 409, 500}:
|
||||||
error_message = response.json().get("error", "Unknown error occurred")
|
error_message = response.json().get("error", "Unknown error occurred")
|
||||||
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
|
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -32,7 +32,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
|
||||||
|
|
||||||
model = tool_parameters.get("model", "core")
|
model = tool_parameters.get("model", "core")
|
||||||
|
|
||||||
if model in ["sd3", "sd3-turbo"]:
|
if model in {"sd3", "sd3-turbo"}:
|
||||||
payload["model"] = tool_parameters.get("model")
|
payload["model"] = tool_parameters.get("model")
|
||||||
|
|
||||||
if model != "sd3-turbo":
|
if model != "sd3-turbo":
|
||||||
|
|
|
@ -38,7 +38,7 @@ class VannaTool(BuiltinTool):
|
||||||
vn = VannaDefault(model=model, api_key=api_key)
|
vn = VannaDefault(model=model, api_key=api_key)
|
||||||
|
|
||||||
db_type = tool_parameters.get("db_type", "")
|
db_type = tool_parameters.get("db_type", "")
|
||||||
if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]:
|
if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}:
|
||||||
if not db_name:
|
if not db_name:
|
||||||
return self.create_text_message("Please input database name")
|
return self.create_text_message("Please input database name")
|
||||||
if not username:
|
if not username:
|
||||||
|
|
|
@ -19,7 +19,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
|
||||||
|
|
||||||
class BuiltinToolProviderController(ToolProviderController):
|
class BuiltinToolProviderController(ToolProviderController):
|
||||||
def __init__(self, **data: Any) -> None:
|
def __init__(self, **data: Any) -> None:
|
||||||
if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
|
if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -153,10 +153,10 @@ class ToolProviderController(BaseModel, ABC):
|
||||||
|
|
||||||
# check type
|
# check type
|
||||||
credential_schema = credentials_need_to_validate[credential_name]
|
credential_schema = credentials_need_to_validate[credential_name]
|
||||||
if (
|
if credential_schema in {
|
||||||
credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT
|
ToolProviderCredentials.CredentialsType.SECRET_INPUT,
|
||||||
or credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT
|
ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||||
):
|
}:
|
||||||
if not isinstance(credentials[credential_name], str):
|
if not isinstance(credentials[credential_name], str):
|
||||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||||
|
|
||||||
|
@ -184,11 +184,11 @@ class ToolProviderController(BaseModel, ABC):
|
||||||
if credential_schema.default is not None:
|
if credential_schema.default is not None:
|
||||||
default_value = credential_schema.default
|
default_value = credential_schema.default
|
||||||
# parse default value into the correct type
|
# parse default value into the correct type
|
||||||
if (
|
if credential_schema.type in {
|
||||||
credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT
|
ToolProviderCredentials.CredentialsType.SECRET_INPUT,
|
||||||
or credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT
|
ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||||
or credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT
|
ToolProviderCredentials.CredentialsType.SELECT,
|
||||||
):
|
}:
|
||||||
default_value = str(default_value)
|
default_value = str(default_value)
|
||||||
|
|
||||||
credentials[credential_name] = default_value
|
credentials[credential_name] = default_value
|
||||||
|
|
|
@ -5,7 +5,7 @@ from urllib.parse import urlencode
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
import core.helper.ssrf_proxy as ssrf_proxy
|
from core.helper import ssrf_proxy
|
||||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||||
|
@ -191,7 +191,7 @@ class ApiTool(Tool):
|
||||||
else:
|
else:
|
||||||
body = body
|
body = body
|
||||||
|
|
||||||
if method in ("get", "head", "post", "put", "delete", "patch"):
|
if method in {"get", "head", "post", "put", "delete", "patch"}:
|
||||||
response = getattr(ssrf_proxy, method)(
|
response = getattr(ssrf_proxy, method)(
|
||||||
url,
|
url,
|
||||||
params=params,
|
params=params,
|
||||||
|
@ -224,9 +224,9 @@ class ApiTool(Tool):
|
||||||
elif option["type"] == "string":
|
elif option["type"] == "string":
|
||||||
return str(value)
|
return str(value)
|
||||||
elif option["type"] == "boolean":
|
elif option["type"] == "boolean":
|
||||||
if str(value).lower() in ["true", "1"]:
|
if str(value).lower() in {"true", "1"}:
|
||||||
return True
|
return True
|
||||||
elif str(value).lower() in ["false", "0"]:
|
elif str(value).lower() in {"false", "0"}:
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
continue # Not a boolean, try next option
|
continue # Not a boolean, try next option
|
||||||
|
|
|
@ -189,10 +189,7 @@ class ToolEngine:
|
||||||
result += response.message
|
result += response.message
|
||||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||||
result += f"result link: {response.message}. please tell user to check it."
|
result += f"result link: {response.message}. please tell user to check it."
|
||||||
elif (
|
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||||
response.type == ToolInvokeMessage.MessageType.IMAGE_LINK
|
|
||||||
or response.type == ToolInvokeMessage.MessageType.IMAGE
|
|
||||||
):
|
|
||||||
result += (
|
result += (
|
||||||
"image has been created and sent to user already, you do not need to create it,"
|
"image has been created and sent to user already, you do not need to create it,"
|
||||||
" just tell the user to check it now."
|
" just tell the user to check it now."
|
||||||
|
@ -212,10 +209,7 @@ class ToolEngine:
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
for response in tool_response:
|
for response in tool_response:
|
||||||
if (
|
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||||
response.type == ToolInvokeMessage.MessageType.IMAGE_LINK
|
|
||||||
or response.type == ToolInvokeMessage.MessageType.IMAGE
|
|
||||||
):
|
|
||||||
mimetype = None
|
mimetype = None
|
||||||
if response.meta.get("mime_type"):
|
if response.meta.get("mime_type"):
|
||||||
mimetype = response.meta.get("mime_type")
|
mimetype = response.meta.get("mime_type")
|
||||||
|
@ -297,7 +291,7 @@ class ToolEngine:
|
||||||
belongs_to="assistant",
|
belongs_to="assistant",
|
||||||
url=message.url,
|
url=message.url,
|
||||||
upload_file_id=None,
|
upload_file_id=None,
|
||||||
created_by_role=("account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"),
|
created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"),
|
||||||
created_by=user_id,
|
created_by=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ class ToolFileMessageTransformer:
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK:
|
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
|
||||||
result.append(message)
|
result.append(message)
|
||||||
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
|
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||||
# try to download image
|
# try to download image
|
||||||
|
|
|
@ -165,7 +165,7 @@ class ApiBasedToolSchemaParser:
|
||||||
elif "schema" in parameter and "type" in parameter["schema"]:
|
elif "schema" in parameter and "type" in parameter["schema"]:
|
||||||
typ = parameter["schema"]["type"]
|
typ = parameter["schema"]["type"]
|
||||||
|
|
||||||
if typ == "integer" or typ == "number":
|
if typ in {"integer", "number"}:
|
||||||
return ToolParameter.ToolParameterType.NUMBER
|
return ToolParameter.ToolParameterType.NUMBER
|
||||||
elif typ == "boolean":
|
elif typ == "boolean":
|
||||||
return ToolParameter.ToolParameterType.BOOLEAN
|
return ToolParameter.ToolParameterType.BOOLEAN
|
||||||
|
|
|
@ -313,7 +313,7 @@ def normalize_whitespace(text):
|
||||||
|
|
||||||
|
|
||||||
def is_leaf(element):
|
def is_leaf(element):
|
||||||
return element.name in ["p", "li"]
|
return element.name in {"p", "li"}
|
||||||
|
|
||||||
|
|
||||||
def is_text(element):
|
def is_text(element):
|
||||||
|
|
|
@ -51,7 +51,7 @@ class RouteNodeState(BaseModel):
|
||||||
|
|
||||||
:param run_result: run result
|
:param run_result: run result
|
||||||
"""
|
"""
|
||||||
if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]:
|
if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}:
|
||||||
raise Exception(f"Route state {self.id} already finished")
|
raise Exception(f"Route state {self.id} already finished")
|
||||||
|
|
||||||
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||||
|
|
|
@ -148,11 +148,11 @@ class AnswerStreamGeneratorRouter:
|
||||||
for edge in reverse_edges:
|
for edge in reverse_edges:
|
||||||
source_node_id = edge.source_node_id
|
source_node_id = edge.source_node_id
|
||||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||||
if source_node_type in (
|
if source_node_type in {
|
||||||
NodeType.ANSWER.value,
|
NodeType.ANSWER.value,
|
||||||
NodeType.IF_ELSE.value,
|
NodeType.IF_ELSE.value,
|
||||||
NodeType.QUESTION_CLASSIFIER.value,
|
NodeType.QUESTION_CLASSIFIER.value,
|
||||||
):
|
}:
|
||||||
answer_dependencies[answer_node_id].append(source_node_id)
|
answer_dependencies[answer_node_id].append(source_node_id)
|
||||||
else:
|
else:
|
||||||
cls._recursive_fetch_answer_dependencies(
|
cls._recursive_fetch_answer_dependencies(
|
||||||
|
|
|
@ -136,10 +136,10 @@ class EndStreamGeneratorRouter:
|
||||||
for edge in reverse_edges:
|
for edge in reverse_edges:
|
||||||
source_node_id = edge.source_node_id
|
source_node_id = edge.source_node_id
|
||||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||||
if source_node_type in (
|
if source_node_type in {
|
||||||
NodeType.IF_ELSE.value,
|
NodeType.IF_ELSE.value,
|
||||||
NodeType.QUESTION_CLASSIFIER,
|
NodeType.QUESTION_CLASSIFIER,
|
||||||
):
|
}:
|
||||||
end_dependencies[end_node_id].append(source_node_id)
|
end_dependencies[end_node_id].append(source_node_id)
|
||||||
else:
|
else:
|
||||||
cls._recursive_fetch_end_dependencies(
|
cls._recursive_fetch_end_dependencies(
|
||||||
|
|
|
@ -6,8 +6,8 @@ from urllib.parse import urlencode
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
import core.helper.ssrf_proxy as ssrf_proxy
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from core.helper import ssrf_proxy
|
||||||
from core.workflow.entities.variable_entities import VariableSelector
|
from core.workflow.entities.variable_entities import VariableSelector
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.nodes.http_request.entities import (
|
from core.workflow.nodes.http_request.entities import (
|
||||||
|
@ -176,7 +176,7 @@ class HttpExecutor:
|
||||||
elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set:
|
elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set:
|
||||||
self.headers["Content-Type"] = "application/x-www-form-urlencoded"
|
self.headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||||
|
|
||||||
if node_data.body.type in ["form-data", "x-www-form-urlencoded"]:
|
if node_data.body.type in {"form-data", "x-www-form-urlencoded"}:
|
||||||
body = self._to_dict(body_data)
|
body = self._to_dict(body_data)
|
||||||
|
|
||||||
if node_data.body.type == "form-data":
|
if node_data.body.type == "form-data":
|
||||||
|
@ -187,7 +187,7 @@ class HttpExecutor:
|
||||||
self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
|
self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
|
||||||
else:
|
else:
|
||||||
self.body = urlencode(body)
|
self.body = urlencode(body)
|
||||||
elif node_data.body.type in ["json", "raw-text"]:
|
elif node_data.body.type in {"json", "raw-text"}:
|
||||||
self.body = body_data
|
self.body = body_data
|
||||||
elif node_data.body.type == "none":
|
elif node_data.body.type == "none":
|
||||||
self.body = ""
|
self.body = ""
|
||||||
|
@ -258,7 +258,7 @@ class HttpExecutor:
|
||||||
"follow_redirects": True,
|
"follow_redirects": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.method in ("get", "head", "post", "put", "delete", "patch"):
|
if self.method in {"get", "head", "post", "put", "delete", "patch"}:
|
||||||
response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
|
response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid http method {self.method}")
|
raise ValueError(f"Invalid http method {self.method}")
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user