diff --git a/api/.env.example b/api/.env.example index db20c95903..22097ad2a0 100644 --- a/api/.env.example +++ b/api/.env.example @@ -233,6 +233,8 @@ VIKINGDB_SOCKET_TIMEOUT=30 UPLOAD_FILE_SIZE_LIMIT=15 UPLOAD_FILE_BATCH_LIMIT=5 UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 +UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 +UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 # Model Configuration MULTIMODAL_SEND_IMAGE_FORMAT=base64 @@ -310,6 +312,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000 WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 +MAX_VARIABLE_SIZE=204800 # App configuration APP_MAX_EXECUTION_TIME=1200 diff --git a/api/.vscode/launch.json.example b/api/.vscode/launch.json.example index e9f8e42dd5..b9e32e2511 100644 --- a/api/.vscode/launch.json.example +++ b/api/.vscode/launch.json.example @@ -1,8 +1,15 @@ { "version": "0.2.0", + "compounds": [ + { + "name": "Launch Flask and Celery", + "configurations": ["Python: Flask", "Python: Celery"] + } + ], "configurations": [ { "name": "Python: Flask", + "consoleName": "Flask", "type": "debugpy", "request": "launch", "python": "${workspaceFolder}/.venv/bin/python", @@ -17,12 +24,12 @@ }, "args": [ "run", - "--host=0.0.0.0", "--port=5001" ] }, { "name": "Python: Celery", + "consoleName": "Celery", "type": "debugpy", "request": "launch", "python": "${workspaceFolder}/.venv/bin/python", @@ -45,10 +52,10 @@ "-c", "1", "--loglevel", - "info", + "DEBUG", "-Q", "dataset,generation,mail,ops_trace,app_deletion" ] - }, + } ] -} \ No newline at end of file +} diff --git a/api/commands.py b/api/commands.py index 5b7f79c8f0..f2809be8e7 100644 --- a/api/commands.py +++ b/api/commands.py @@ -19,7 +19,7 @@ from extensions.ext_redis import redis_client from libs.helper import email as email_validate from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair -from models.account import Tenant +from models import Tenant from models.dataset import Dataset, DatasetCollectionBinding, 
DocumentSegment from models.dataset import Document as DatasetDocument from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation @@ -426,14 +426,14 @@ def convert_to_agent_apps(): # fetch first 1000 apps sql_query = """SELECT a.id AS id FROM apps a INNER JOIN app_model_configs am ON a.app_model_config_id=am.id - WHERE a.mode = 'chat' - AND am.agent_mode is not null + WHERE a.mode = 'chat' + AND am.agent_mode is not null AND ( - am.agent_mode like '%"strategy": "function_call"%' + am.agent_mode like '%"strategy": "function_call"%' OR am.agent_mode like '%"strategy": "react"%' - ) + ) AND ( - am.agent_mode like '{"enabled": true%' + am.agent_mode like '{"enabled": true%' OR am.agent_mode like '{"max_iteration": %' ) ORDER BY a.created_at DESC LIMIT 1000 """ diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 8a24392ea2..2e4a09518b 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -20,11 +20,11 @@ class SecurityConfig(BaseSettings): Security-related configurations for the application """ - SECRET_KEY: Optional[str] = Field( + SECRET_KEY: str = Field( description="Secret key for secure session cookie signing." "Make sure you are changing this key for your deployment with a strong key." 
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.", - default=None, + default="", ) RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field( @@ -186,6 +186,16 @@ class FileUploadConfig(BaseSettings): default=10, ) + UPLOAD_VIDEO_FILE_SIZE_LIMIT: NonNegativeInt = Field( + description="video file size limit in Megabytes for uploading files", + default=100, + ) + + UPLOAD_AUDIO_FILE_SIZE_LIMIT: NonNegativeInt = Field( + description="audio file size limit in Megabytes for uploading files", + default=50, + ) + BATCH_UPLOAD_LIMIT: NonNegativeInt = Field( description="Maximum number of files allowed in a batch upload operation", default=20, @@ -364,8 +374,8 @@ class WorkflowConfig(BaseSettings): ) MAX_VARIABLE_SIZE: PositiveInt = Field( - description="Maximum size in bytes for a single variable in workflows. Default to 5KB.", - default=5 * 1024, + description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.", + default=200 * 1024, ) @@ -493,6 +503,7 @@ class RagEtlConfig(BaseSettings): Configuration for RAG ETL processes """ + # TODO: This config is not only for rag etl, it is also for file upload, we should move it to file upload config ETL_TYPE: str = Field( description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'", default="dify", @@ -559,7 +570,7 @@ class IndexingConfig(BaseSettings): class ImageFormatConfig(BaseSettings): - MULTIMODAL_SEND_IMAGE_FORMAT: str = Field( + MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field( description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64", default="base64", ) diff --git a/api/constants/__init__.py b/api/constants/__init__.py index 75eaf81638..66b9c0b632 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -1,2 +1,22 @@ +from configs import dify_config + HIDDEN_VALUE = "[__HIDDEN__]" UUID_NIL = "00000000-0000-0000-0000-000000000000" + +IMAGE_EXTENSIONS = 
["jpg", "jpeg", "png", "webp", "gif", "svg"] +IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) + +VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"] +VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS]) + +AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"] +AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS]) + + +if dify_config.ETL_TYPE == "Unstructured": + DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"] + DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "ppt", "xml", "epub")) + DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) +else: + DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"] + DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 623a1a28eb..85380b7330 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -1,7 +1,9 @@ from contextvars import ContextVar +from typing import TYPE_CHECKING -from core.workflow.entities.variable_pool import VariablePool +if TYPE_CHECKING: + from core.workflow.entities.variable_pool import VariablePool tenant_id: ContextVar[str] = ContextVar("tenant_id") -workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool") +workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index c1e16b3b9b..b60a424d98 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -22,7 +22,8 @@ from fields.conversation_fields import ( ) from libs.helper import DatetimeString from libs.login import login_required -from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation +from models import Conversation, EndUser, Message, 
MessageAnnotation +from models.model import AppMode class CompletionConversationApi(Resource): diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 26da1ef26d..115a832da9 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -12,7 +12,7 @@ from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from fields.app_fields import app_site_fields from libs.login import login_required -from models.model import Site +from models import Site def parse_app_site_args(): diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 0a693b84e2..a8f601aeee 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -13,14 +13,14 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.segments import factory -from core.errors.error import AppInvokeQuotaExceededError +from factories import variable_factory from fields.workflow_fields import workflow_fields from fields.workflow_run_fields import workflow_run_node_execution_fields from libs import helper from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required -from models.model import App, AppMode +from models import App +from models.model import AppMode from services.app_dsl_service import AppDslService from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowHashNotEqualError @@ -101,9 +101,13 @@ class DraftWorkflowApi(Resource): try: environment_variables_list = args.get("environment_variables") or [] - environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] + 
environment_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list + ] conversation_variables_list = args.get("conversation_variables") or [] - conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] + conversation_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list + ] workflow = workflow_service.sync_draft_workflow( app_model=app_model, graph=args["graph"], @@ -273,17 +277,15 @@ class DraftWorkflowRunApi(Resource): parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() - try: - response = AppGenerateService.generate( - app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True - ) + response = AppGenerateService.generate( + app_model=app_model, + user=current_user, + args=args, + invoke_from=InvokeFrom.DEBUGGER, + streaming=True, + ) - return helper.compact_generate_response(response) - except (ValueError, AppInvokeQuotaExceededError) as e: - raise e - except Exception as e: - logging.exception("internal server error.") - raise InternalServerError() + return helper.compact_generate_response(response) class WorkflowTaskStopApi(Resource): diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index dc962409cc..629b7a8bf4 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -7,7 +7,8 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from libs.login import login_required -from models.model import App, AppMode +from models import App +from models.model import AppMode from services.workflow_app_service import WorkflowAppService diff --git 
a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index a055d03deb..5824ead9c3 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -13,7 +13,8 @@ from fields.workflow_run_fields import ( ) from libs.helper import uuid_value from libs.login import login_required -from models.model import App, AppMode +from models import App +from models.model import AppMode from services.workflow_run_service import WorkflowRunService diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index c7e54f2be0..f46af0f1ca 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -13,8 +13,8 @@ from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from libs.helper import DatetimeString from libs.login import login_required +from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode -from models.workflow import WorkflowRunTriggeredFrom class WorkflowDailyRunsStatistic(Resource): diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 5e0a4bc814..c71ee8e5df 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -5,7 +5,8 @@ from typing import Optional, Union from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_user -from models.model import App, AppMode +from models import App +from models.model import AppMode def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index ba0e07cd16..282e69448e 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -13,7 +13,8 @@ from 
events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import extract_remote_ip from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo -from models.account import Account, AccountStatus +from models import Account +from models.account import AccountStatus from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import AccountNotFoundError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 0e1acab946..a2c9760782 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -15,8 +15,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from libs.login import login_required -from models.dataset import Document -from models.source import DataSourceOauthBinding +from models import DataSourceOauthBinding, Document from services.dataset_service import DatasetService, DocumentService from tasks.document_indexing_sync_task import document_indexing_sync_task diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 6583356d23..16a77ed880 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -24,8 +24,8 @@ from fields.app_fields import related_app_list from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.document_fields import document_status_fields from libs.login import login_required -from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment -from models.model import ApiToken, UploadFile +from models import ApiToken, Dataset, 
Document, DocumentSegment, UploadFile +from models.dataset import DatasetPermissionEnum from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 829ef11e52..cdabac491e 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -46,8 +46,7 @@ from fields.document_fields import ( document_with_segments_fields, ) from libs.login import login_required -from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment -from models.model import UploadFile +from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from services.dataset_service import DatasetService, DocumentService from tasks.add_document_to_index_task import add_document_to_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 2405649387..08ea414288 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -24,7 +24,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import segment_fields from libs.login import login_required -from models.dataset import DocumentSegment +from models import DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index 846aa70e86..5ed9a61545 100644 --- 
a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -1,9 +1,12 @@ +import urllib.parse + from flask import request from flask_login import current_user from flask_restful import Resource, marshal_with import services from configs import dify_config +from constants import DOCUMENT_EXTENSIONS from controllers.console import api from controllers.console.datasets.error import ( FileTooLargeError, @@ -13,9 +16,10 @@ from controllers.console.datasets.error import ( ) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from fields.file_fields import file_fields, upload_config_fields +from core.helper import ssrf_proxy +from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields from libs.login import login_required -from services.file_service import ALLOWED_EXTENSIONS, UNSTRUCTURED_ALLOWED_EXTENSIONS, FileService +from services.file_service import FileService PREVIEW_WORDS_LIMIT = 3000 @@ -51,7 +55,7 @@ class FileApi(Resource): if len(request.files) > 1: raise TooManyFilesError() try: - upload_file = FileService.upload_file(file, current_user) + upload_file = FileService.upload_file(file=file, user=current_user) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: @@ -75,11 +79,24 @@ class FileSupportTypeApi(Resource): @login_required @account_initialization_required def get(self): - etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS - return {"allowed_extensions": allowed_extensions} + return {"allowed_extensions": DOCUMENT_EXTENSIONS} + + +class RemoteFileInfoApi(Resource): + @marshal_with(remote_file_info_fields) + def get(self, url): + decoded_url = urllib.parse.unquote(url) 
+ try: + response = ssrf_proxy.head(decoded_url) + return { + "file_type": response.headers.get("Content-Type", "application/octet-stream"), + "file_length": int(response.headers.get("Content-Length", 0)), + } + except Exception as e: + return {"error": str(e)}, 400 api.add_resource(FileApi, "/files/upload") api.add_resource(FilePreviewApi, "/files//preview") api.add_resource(FileSupportTypeApi, "/files/support-type") +api.add_resource(RemoteFileInfoApi, "/remote-files/") diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 408afc33a0..d72715a38c 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -11,7 +11,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields from libs.login import login_required -from models.model import App, InstalledApp, RecommendedApp +from models import App, InstalledApp, RecommendedApp from services.account_service import TenantService diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index a7ccf737a8..0fc9637479 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -18,7 +18,7 @@ message_fields = { "inputs": fields.Raw, "query": fields.String, "answer": fields.String, - "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "message_files": fields.List(fields.Nested(message_file_fields)), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "created_at": TimestampField, } diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 3c9317847b..49ea81a8a0 100644 --- a/api/controllers/console/explore/wraps.py +++ 
b/api/controllers/console/explore/wraps.py @@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from libs.login import login_required -from models.model import InstalledApp +from models import InstalledApp def installed_app_required(view=None): diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index dec426128f..97f5625726 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -20,7 +20,7 @@ from extensions.ext_database import db from fields.member_fields import account_fields from libs.helper import TimestampField, timezone from libs.login import login_required -from models.account import AccountIntegrate, InvitationCode +from models import AccountIntegrate, InvitationCode from services.account_service import AccountService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index d2a17b133b..aaa24d501c 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -360,16 +360,15 @@ class ToolWorkflowProviderCreateApi(Resource): args = reqparser.parse_args() return WorkflowToolManageService.create_workflow_tool( - user_id, - tenant_id, - args["workflow_app_id"], - args["name"], - args["label"], - args["icon"], - args["description"], - args["parameters"], - args["privacy_policy"], - args.get("labels", []), + user_id=user_id, + tenant_id=tenant_id, + workflow_app_id=args["workflow_app_id"], + name=args["name"], + label=args["label"], + icon=args["icon"], + description=args["description"], + parameters=args["parameters"], + privacy_policy=args["privacy_policy"], ) diff --git 
a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index af3ebc099b..96f866fca2 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -198,7 +198,7 @@ class WebappLogoWorkspaceApi(Resource): raise UnsupportedFileTypeError() try: - upload_file = FileService.upload_file(file, current_user, True) + upload_file = FileService.upload_file(file=file, user=current_user) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index a56c1c332d..4b2d61e7c3 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -10,6 +10,10 @@ from services.file_service import FileService class ImagePreviewApi(Resource): + """ + Deprecated + """ + def get(self, file_id): file_id = str(file_id) @@ -21,7 +25,36 @@ class ImagePreviewApi(Resource): return {"content": "Invalid request."}, 400 try: - generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign) + generator, mimetype = FileService.get_image_preview( + file_id=file_id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + return Response(generator, mimetype=mimetype) + + +class FilePreviewApi(Resource): + def get(self, file_id): + file_id = str(file_id) + + timestamp = request.args.get("timestamp") + nonce = request.args.get("nonce") + sign = request.args.get("sign") + + if not timestamp or not nonce or not sign: + return {"content": "Invalid request."}, 400 + + try: + generator, mimetype = FileService.get_signed_file_preview( + file_id=file_id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() @@ -49,4 +82,5 @@ 
class WorkspaceWebappLogoApi(Resource): api.add_resource(ImagePreviewApi, "/files//image-preview") +api.add_resource(FilePreviewApi, "/files//file-preview") api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces//webapp-logo") diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 406cd42214..104b7cd9bb 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -16,6 +16,7 @@ class ToolFilePreviewApi(Resource): parser.add_argument("timestamp", type=str, required=True, location="args") parser.add_argument("nonce", type=str, required=True, location="args") parser.add_argument("sign", type=str, required=True, location="args") + parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") args = parser.parse_args() @@ -28,18 +29,27 @@ class ToolFilePreviewApi(Resource): raise Forbidden("Invalid request.") try: - result = ToolFileManager.get_file_generator_by_tool_file_id( + stream, tool_file = ToolFileManager.get_file_generator_by_tool_file_id( file_id, ) - if not result: + if not stream or not tool_file: raise NotFound("file is not found") - - generator, mimetype = result except Exception: raise UnsupportedFileTypeError() - return Response(generator, mimetype=mimetype) + response = Response( + stream, + mimetype=tool_file.mimetype, + direct_passthrough=True, + headers={ + "Content-Length": str(tool_file.size), + }, + ) + if args["as_attachment"]: + response.headers["Content-Disposition"] = f"attachment; filename={tool_file.name}" + + return response api.add_resource(ToolFilePreviewApi, "/files/tools/.") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index a70ee89b5e..d9a9fad13c 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -48,7 +48,7 @@ class MessageListApi(Resource): "tool_input": fields.String, "created_at": TimestampField, "observation": 
fields.String, - "message_files": fields.List(fields.String, attribute="files"), + "message_files": fields.List(fields.String), } message_fields = { @@ -58,7 +58,7 @@ class MessageListApi(Resource): "inputs": fields.Raw, "query": fields.String, "answer": fields.String(attribute="re_sign_file_url_answer"), - "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "message_files": fields.List(fields.Nested(message_file_fields)), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "created_at": TimestampField, diff --git a/api/controllers/web/file.py b/api/controllers/web/file.py index 253b1d511c..6b9c267003 100644 --- a/api/controllers/web/file.py +++ b/api/controllers/web/file.py @@ -1,3 +1,5 @@ +import urllib.parse + from flask import request from flask_restful import marshal_with @@ -5,7 +7,8 @@ import services from controllers.web import api from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError from controllers.web.wraps import WebApiResource -from fields.file_fields import file_fields +from core.helper import ssrf_proxy +from fields.file_fields import file_fields, remote_file_info_fields from services.file_service import FileService @@ -31,4 +34,19 @@ class FileApi(WebApiResource): return upload_file, 201 +class RemoteFileInfoApi(WebApiResource): + @marshal_with(remote_file_info_fields) + def get(self, url): + decoded_url = urllib.parse.unquote(url) + try: + response = ssrf_proxy.head(decoded_url) + return { + "file_type": response.headers.get("Content-Type", "application/octet-stream"), + "file_length": int(response.headers.get("Content-Length", 0)), + } + except Exception as e: + return {"error": str(e)}, 400 + + api.add_resource(FileApi, "/files/upload") +api.add_resource(RemoteFileInfoApi, "/remote-files/") diff --git a/api/controllers/web/message.py 
b/api/controllers/web/message.py index 2d2a5866c8..98891f5d00 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -22,6 +22,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni from core.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import message_file_fields from fields.message_fields import agent_thought_fields +from fields.raws import FilesContainedField from libs import helper from libs.helper import TimestampField, uuid_value from models.model import AppMode @@ -58,10 +59,10 @@ class MessageListApi(WebApiResource): "id": fields.String, "conversation_id": fields.String, "parent_message_id": fields.String, - "inputs": fields.Raw, + "inputs": FilesContainedField, "query": fields.String, "answer": fields.String(attribute="re_sign_file_url_answer"), - "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "message_files": fields.List(fields.Nested(message_file_fields)), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "created_at": TimestampField, diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 8253f5fc57..b0492e6b6f 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -17,7 +17,7 @@ message_fields = { "inputs": fields.Raw, "query": fields.String, "answer": fields.String, - "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "message_files": fields.List(fields.Nested(message_file_fields)), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "created_at": TimestampField, } diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 5295f97bdb..514dcfbd68 100644 --- a/api/core/agent/base_agent_runner.py +++ 
b/api/core/agent/base_agent_runner.py @@ -16,13 +16,14 @@ from core.app.entities.app_invoke_entities import ( ) from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.file.message_file_parser import MessageFileParser +from core.file import file_manager from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, + LLMUsage, PromptMessage, + PromptMessageContent, PromptMessageTool, SystemPromptMessage, TextPromptMessageContent, @@ -40,9 +41,9 @@ from core.tools.entities.tool_entities import ( from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tools.tool.tool import Tool from core.tools.tool_manager import ToolManager -from core.tools.utils.tool_parameter_converter import ToolParameterConverter from extensions.ext_database import db -from models.model import Conversation, Message, MessageAgentThought +from factories import file_factory +from models.model import Conversation, Message, MessageAgentThought, MessageFile from models.tools import ToolConversationVariables logger = logging.getLogger(__name__) @@ -66,23 +67,6 @@ class BaseAgentRunner(AppRunner): db_variables: Optional[ToolConversationVariables] = None, model_instance: ModelInstance = None, ) -> None: - """ - Agent runner - :param tenant_id: tenant id - :param application_generate_entity: application generate entity - :param conversation: conversation - :param app_config: app generate entity - :param model_config: model config - :param config: dataset config - :param queue_manager: queue manager - :param message: message - :param user_id: user id - :param memory: memory - :param prompt_messages: prompt messages - 
:param variables_pool: variables pool - :param db_variables: db variables - :param model_instance: model instance - """ self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity self.conversation = conversation @@ -180,7 +164,7 @@ class BaseAgentRunner(AppRunner): if parameter.form != ToolParameter.ToolParameterForm.LLM: continue - parameter_type = ToolParameterConverter.get_parameter_type(parameter.type) + parameter_type = parameter.type.as_normal_type() enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] @@ -265,7 +249,7 @@ class BaseAgentRunner(AppRunner): if parameter.form != ToolParameter.ToolParameterForm.LLM: continue - parameter_type = ToolParameterConverter.get_parameter_type(parameter.type) + parameter_type = parameter.type.as_normal_type() enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] @@ -511,26 +495,24 @@ class BaseAgentRunner(AppRunner): return result def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: - message_file_parser = MessageFileParser( - tenant_id=self.tenant_id, - app_id=self.app_config.app_id, - ) - - files = message.message_files + files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.transform_message_files(files, file_extra_config) + file_objs = file_factory.build_from_message_files( + message_files=files, tenant_id=self.tenant_id, config=file_extra_config + ) else: file_objs = [] if not file_objs: return UserPromptMessage(content=message.query) else: - prompt_message_contents = [TextPromptMessageContent(data=message.query)] + prompt_message_contents: list[PromptMessageContent] = [] + 
prompt_message_contents.append(TextPromptMessageContent(data=message.query)) for file_obj in file_objs: - prompt_message_contents.append(file_obj.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) return UserPromptMessage(content=prompt_message_contents) else: diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 5e16373fff..6261a9b12c 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,9 +1,11 @@ import json from core.agent.cot_agent_runner import CotAgentRunner -from core.model_runtime.entities.message_entities import ( +from core.file import file_manager +from core.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageContent, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, @@ -32,9 +34,10 @@ class CotChatAgentRunner(CotAgentRunner): Organize user query """ if self.files: - prompt_message_contents = [TextPromptMessageContent(data=query)] + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=query)) for file_obj in self.files: - prompt_message_contents.append(file_obj.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 7b22025582..9083b4e85f 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -7,10 +7,15 @@ from typing import Any, Optional, Union from core.agent.base_agent_runner import BaseAgentRunner from core.app.apps.base_app_queue_manager import PublishFrom from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.model_runtime.entities.llm_entities import LLMResult, 
LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import ( +from core.file import file_manager +from core.model_runtime.entities import ( AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, PromptMessage, + PromptMessageContent, PromptMessageContentType, SystemPromptMessage, TextPromptMessageContent, @@ -390,9 +395,10 @@ class FunctionCallAgentRunner(BaseAgentRunner): Organize user query """ if self.files: - prompt_message_contents = [TextPromptMessageContent(data=query)] + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=query)) for file_obj in self.files: - prompt_message_contents.append(file_obj.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index a1bfde3208..126eb0b41e 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -53,12 +53,11 @@ class BasicVariablesConfigManager: VariableEntity( type=variable_type, variable=variable.get("variable"), - description=variable.get("description"), + description=variable.get("description", ""), label=variable.get("label"), required=variable.get("required", False), max_length=variable.get("max_length"), - options=variable.get("options"), - default=variable.get("default"), + options=variable.get("options", []), ) ) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 7e5899bafa..d8fa08c0a3 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,11 +1,12 @@ +from collections.abc import Sequence from enum import Enum from 
typing import Any, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field -from core.file.file_obj import FileExtraConfig +from core.file import FileExtraConfig, FileTransferMethod, FileType from core.model_runtime.entities.message_entities import PromptMessageRole -from models import AppMode +from models.model import AppMode class ModelConfigEntity(BaseModel): @@ -69,7 +70,7 @@ class PromptTemplateEntity(BaseModel): ADVANCED = "advanced" @classmethod - def value_of(cls, value: str) -> "PromptType": + def value_of(cls, value: str): """ Get value of given mode. @@ -93,6 +94,8 @@ class VariableEntityType(str, Enum): PARAGRAPH = "paragraph" NUMBER = "number" EXTERNAL_DATA_TOOL = "external_data_tool" + FILE = "file" + FILE_LIST = "file-list" class VariableEntity(BaseModel): @@ -102,13 +105,14 @@ class VariableEntity(BaseModel): variable: str label: str - description: Optional[str] = None + description: str = "" type: VariableEntityType required: bool = False max_length: Optional[int] = None - options: Optional[list[str]] = None - default: Optional[str] = None - hint: Optional[str] = None + options: Sequence[str] = Field(default_factory=list) + allowed_file_types: Sequence[FileType] = Field(default_factory=list) + allowed_file_extensions: Sequence[str] = Field(default_factory=list) + allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) class ExternalDataVariableEntity(BaseModel): @@ -136,7 +140,7 @@ class DatasetRetrieveConfigEntity(BaseModel): MULTIPLE = "multiple" @classmethod - def value_of(cls, value: str) -> "RetrieveStrategy": + def value_of(cls, value: str): """ Get value of given mode. 
diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 7a275cb532..6d301f6ea7 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,12 +1,13 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any -from core.file.file_obj import FileExtraConfig +from core.file.models import FileExtraConfig +from models import FileUploadConfig class FileUploadConfigManager: @classmethod - def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]: + def convert(cls, config: Mapping[str, Any], is_vision: bool = True): """ Convert model config to model config @@ -15,19 +16,18 @@ class FileUploadConfigManager: """ file_upload_dict = config.get("file_upload") if file_upload_dict: - if file_upload_dict.get("image"): - if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]: - image_config = { - "number_limits": file_upload_dict["image"]["number_limits"], - "transfer_methods": file_upload_dict["image"]["transfer_methods"], + if file_upload_dict.get("enabled"): + data = { + "image_config": { + "number_limits": file_upload_dict["number_limits"], + "transfer_methods": file_upload_dict["allowed_file_upload_methods"], } + } - if is_vision: - image_config["detail"] = file_upload_dict["image"]["detail"] + if is_vision: + data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low") - return FileExtraConfig(image_config=image_config) - - return None + return FileExtraConfig.model_validate(data) @classmethod def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]: @@ -39,29 +39,7 @@ class FileUploadConfigManager: """ if not config.get("file_upload"): config["file_upload"] = {} - - if not isinstance(config["file_upload"], dict): - raise ValueError("file_upload must be 
of dict type") - - # check image config - if not config["file_upload"].get("image"): - config["file_upload"]["image"] = {"enabled": False} - - if config["file_upload"]["image"]["enabled"]: - number_limits = config["file_upload"]["image"]["number_limits"] - if number_limits < 1 or number_limits > 6: - raise ValueError("number_limits must be in [1, 6]") - - if is_vision: - detail = config["file_upload"]["image"]["detail"] - if detail not in {"high", "low"}: - raise ValueError("detail must be in ['high', 'low']") - - transfer_methods = config["file_upload"]["image"]["transfer_methods"] - if not isinstance(transfer_methods, list): - raise ValueError("transfer_methods must be of list type") - for method in transfer_methods: - if method not in {"remote_url", "local_file"}: - raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") + else: + FileUploadConfig.model_validate(config["file_upload"]) return config, ["file_upload"] diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index 4b117d87f8..2f1da38082 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -17,6 +17,6 @@ class WorkflowVariablesConfigManager: # variables for variable in user_input_form: - variables.append(VariableEntity(**variable)) + variables.append(VariableEntity.model_validate(variable)) return variables diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 7fff925f4b..39ab87c914 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -21,11 +21,12 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities 
import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse -from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db +from factories import file_factory from models.account import Account +from models.enums import CreatedByRole from models.model import App, Conversation, EndUser, Message from models.workflow import Workflow @@ -96,10 +97,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # parse files files = args["files"] if args.get("files") else [] - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + user_id=user.id, + role=role, + config=file_extra_config, + ) else: file_objs = [] @@ -107,8 +114,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # get tracing instance - user_id = user.id if isinstance(user, Account) else user.session_id - trace_manager = TraceQueueManager(app_model.id, user_id) + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) if invoke_from == InvokeFrom.DEBUGGER: # always enable retriever resource in debugger mode @@ -120,7 +128,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): task_id=str(uuid.uuid4()), app_config=app_config, 
conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + inputs=conversation.inputs + if conversation + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 1dcd051d15..65d744eddf 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,31 +1,27 @@ import logging -import os from collections.abc import Mapping from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session +from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback -from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, - InvokeFrom, -) +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.queue_entities import ( QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent, ) from core.moderation.base import ModerationError -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import UserFrom +from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from models.enums import 
UserFrom from models.model import App, Conversation, EndUser, Message from models.workflow import ConversationVariable, WorkflowType @@ -44,12 +40,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): conversation: Conversation, message: Message, ) -> None: - """ - :param application_generate_entity: application generate entity - :param queue_manager: application queue manager - :param conversation: conversation - :param message: message - """ super().__init__(queue_manager) self.application_generate_entity = application_generate_entity @@ -57,10 +47,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self.message = message def run(self) -> None: - """ - Run application - :return: - """ app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) @@ -81,7 +67,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_id = self.application_generate_entity.user_id workflow_callbacks: list[WorkflowCallback] = [] - if bool(os.environ.get("DEBUG", "False").lower() == "true"): + if dify_config.DEBUG: workflow_callbacks.append(WorkflowLoggingCallback()) if self.application_generate_entity.single_iteration_run: @@ -201,15 +187,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): query: str, message_id: str, ) -> bool: - """ - Handle input moderation - :param app_record: app record - :param app_generate_entity: application generate entity - :param inputs: inputs - :param query: query - :param message_id: message id - :return: - """ try: # process sensitive_word_avoidance _, inputs, query = self.moderation_for_inputs( @@ -229,14 +206,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): def handle_annotation_reply( self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity ) -> bool: - """ - Handle annotation reply - :param app_record: app record - :param message: message - :param query: query - :param app_generate_entity: application generate entity - """ - # 
annotation reply annotation_reply = self.query_app_annotations_to_reply( app_record=app_record, message=message, @@ -258,8 +227,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None: """ Direct output - :param text: text - :return: """ self._publish_event(QueueTextChunkEvent(text=text)) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index fd63c7787f..e4cb3f8527 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -1,7 +1,7 @@ import json import logging import time -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Any, Optional, Union from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -9,6 +9,7 @@ from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, + InvokeFrom, ) from core.app.entities.queue_entities import ( QueueAdvancedChatMessageEndEvent, @@ -50,10 +51,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes import NodeType from events.message_event import message_was_created from extensions.ext_database import db +from models import Conversation, EndUser, Message, MessageFile from models.account import Account -from models.model import Conversation, EndUser, Message +from models.enums import CreatedByRole from models.workflow import ( Workflow, WorkflowNodeExecution, @@ -120,6 +123,7 
@@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._wip_workflow_node_executions = {} self._conversation_name_generate_thread = None + self._recorded_files: list[Mapping[str, Any]] = [] def process(self): """ @@ -298,6 +302,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc elif isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._handle_workflow_node_execution_success(event) + # Record files if it's an answer node or end node + if event.node_type in [NodeType.ANSWER, NodeType.END]: + self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) + response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, @@ -364,7 +372,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - outputs=json.dumps(event.outputs) if event.outputs else None, + outputs=event.outputs, conversation_id=self._conversation.id, trace_manager=trace_manager, ) @@ -490,10 +498,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._conversation_name_generate_thread.join() def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: - """ - Save message. 
- :return: - """ self._refetch_message() self._message.answer = self._task_state.answer @@ -501,6 +505,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._message.message_metadata = ( json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None ) + message_files = [ + MessageFile( + message_id=self._message.id, + type=file["type"], + transfer_method=file["transfer_method"], + url=file["remote_url"], + belongs_to="assistant", + upload_file_id=file["related_id"], + created_by_role=CreatedByRole.ACCOUNT + if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatedByRole.END_USER, + created_by=self._message.from_account_id or self._message.from_end_user_id or "", + ) + for file in self._recorded_files + ] + db.session.add_all(message_files) if graph_runtime_state and graph_runtime_state.llm_usage: usage = graph_runtime_state.llm_usage @@ -540,7 +560,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc del extras["metadata"]["annotation_reply"] return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, id=self._message.id, **extras + task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras ) def _handle_output_moderation_chunk(self, text: str) -> bool: diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 81c3c765dc..de12f5a441 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -18,12 +18,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom -from 
core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db -from models.account import Account -from models.model import App, EndUser +from factories import file_factory +from models import Account, App, EndUser +from models.enums import CreatedByRole logger = logging.getLogger(__name__) @@ -50,7 +50,12 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): ) -> dict: ... def generate( - self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True + self, + app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: bool = True, ) -> Union[dict, Generator[dict, None, None]]: """ Generate App response. @@ -98,12 +103,19 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): # always enable retriever resource in debugger mode override_model_config_dict["retriever_resource"] = {"enabled": True} + role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER + # parse files - files = args["files"] if args.get("files") else [] - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + files = args.get("files") or [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + user_id=user.id, + role=role, + config=file_extra_config, + ) else: file_objs = [] @@ -116,8 +128,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): ) # get tracing instance - user_id = user.id if isinstance(user, Account) else user.session_id - trace_manager = TraceQueueManager(app_model.id, user_id) 
+ trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id) # init application generate entity application_generate_entity = AgentChatAppGenerateEntity( @@ -125,7 +136,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + inputs=conversation.inputs + if conversation + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 15be7000fc..2707ada6cb 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,35 +1,92 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional -from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType +from core.app.app_config.entities import VariableEntityType +from core.file import File, FileExtraConfig +from factories import file_factory + +if TYPE_CHECKING: + from core.app.app_config.entities import AppConfig, VariableEntity + from models.enums import CreatedByRole class BaseAppGenerator: - def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]: + def _prepare_user_inputs( + self, + *, + user_inputs: Optional[Mapping[str, Any]], + app_config: "AppConfig", + user_id: str, + role: "CreatedByRole", + ) -> Mapping[str, Any]: user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values variables = 
app_config.variables - filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} - filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} - return filtered_inputs + user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} + user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()} + # Convert files in inputs to File + entity_dictionary = {item.variable: item for item in app_config.variables} + # Convert single file to File + files_inputs = { + k: file_factory.build_from_mapping( + mapping=v, + tenant_id=app_config.tenant_id, + user_id=user_id, + role=role, + config=FileExtraConfig( + allowed_file_types=entity_dictionary[k].allowed_file_types, + allowed_extensions=entity_dictionary[k].allowed_file_extensions, + allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, + ), + ) + for k, v in user_inputs.items() + if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE + } + # Convert list of files to File + file_list_inputs = { + k: file_factory.build_from_mappings( + mappings=v, + tenant_id=app_config.tenant_id, + user_id=user_id, + role=role, + config=FileExtraConfig( + allowed_file_types=entity_dictionary[k].allowed_file_types, + allowed_extensions=entity_dictionary[k].allowed_file_extensions, + allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, + ), + ) + for k, v in user_inputs.items() + if isinstance(v, list) + # Ensure skip List + and all(isinstance(item, dict) for item in v) + and entity_dictionary[k].type == VariableEntityType.FILE_LIST + } + # Merge all inputs + user_inputs = {**user_inputs, **files_inputs, **file_list_inputs} - def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): - user_input_value = inputs.get(var.variable) - if var.required and not user_input_value: - raise ValueError(f"{var.variable} is required in input form") - if 
not var.required and not user_input_value: - # TODO: should we return None here if the default value is None? - return var.default or "" - if ( - var.type - in { - VariableEntityType.TEXT_INPUT, - VariableEntityType.SELECT, - VariableEntityType.PARAGRAPH, - } - and user_input_value - and not isinstance(user_input_value, str) + # Check if all files are converted to File + if any(filter(lambda v: isinstance(v, dict), user_inputs.values())): + raise ValueError("Invalid input type") + if any( + filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values())) ): + raise ValueError("Invalid input type") + + return user_inputs + + def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"): + user_input_value = inputs.get(var.variable) + if not user_input_value: + if var.required: + raise ValueError(f"{var.variable} is required in input form") + else: + return None + + if var.type in { + VariableEntityType.TEXT_INPUT, + VariableEntityType.SELECT, + VariableEntityType.PARAGRAPH, + } and not isinstance(user_input_value, str): raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string") if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): # may raise ValueError if user_input_value is not a valid number @@ -41,12 +98,24 @@ class BaseAppGenerator: except ValueError: raise ValueError(f"{var.variable} in input form must be a valid number") if var.type == VariableEntityType.SELECT: - options = var.options or [] + options = var.options if user_input_value not in options: raise ValueError(f"{var.variable} in input form must be one of the following: {options}") elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}: - if var.max_length and user_input_value and len(user_input_value) > var.max_length: + if var.max_length and len(user_input_value) > var.max_length: raise ValueError(f"{var.variable} in input form must be less than {var.max_length} 
characters") + elif var.type == VariableEntityType.FILE: + if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File): + raise ValueError(f"{var.variable} in input form must be a file") + elif var.type == VariableEntityType.FILE_LIST: + if not ( + isinstance(user_input_value, list) + and ( + all(isinstance(item, dict) for item in user_input_value) + or all(isinstance(item, File) for item in user_input_value) + ) + ): + raise ValueError(f"{var.variable} in input form must be a list of files") return user_input_value diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 203aca3384..609fd03f22 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -27,7 +27,7 @@ from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from models.model import App, AppMode, Message, MessageAnnotation if TYPE_CHECKING: - from core.file.file_obj import FileVar + from core.file.models import File class AppRunner: @@ -37,7 +37,7 @@ class AppRunner: model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], - files: list["FileVar"], + files: list["File"], query: Optional[str] = None, ) -> int: """ @@ -137,7 +137,7 @@ class AppRunner: model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], - files: list["FileVar"], + files: list["File"], query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None, diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 49b56ecc67..5c074f5306 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -18,11 +18,12 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from 
core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom -from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db +from factories import file_factory from models.account import Account +from models.enums import CreatedByRole from models.model import App, EndUser logger = logging.getLogger(__name__) @@ -100,12 +101,19 @@ class ChatAppGenerator(MessageBasedAppGenerator): # always enable retriever resource in debugger mode override_model_config_dict["retriever_resource"] = {"enabled": True} + role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER + # parse files files = args["files"] if args.get("files") else [] - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + user_id=user.id, + role=role, + config=file_extra_config, + ) else: file_objs = [] @@ -118,7 +126,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): ) # get tracing instance - trace_manager = TraceQueueManager(app_model.id) + trace_manager = TraceQueueManager(app_id=app_model.id) # init application generate entity application_generate_entity = ChatAppGenerateEntity( @@ -126,15 +134,17 @@ class ChatAppGenerator(MessageBasedAppGenerator): app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else 
self._get_cleaned_inputs(inputs, app_config), + inputs=conversation.inputs + if conversation + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, - stream=stream, invoke_from=invoke_from, extras=extras, trace_manager=trace_manager, + stream=stream, ) # init generate records diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 7fce296f2b..46450d39c0 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -17,12 +17,12 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom -from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db -from models.account import Account -from models.model import App, EndUser, Message +from factories import file_factory +from models import Account, App, EndUser, Message +from models.enums import CreatedByRole from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError @@ -88,12 +88,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator): tenant_id=app_model.tenant_id, config=args.get("model_config") ) + role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER + # parse files files = args["files"] if args.get("files") else [] - message_file_parser = 
MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + user_id=user.id, + role=role, + config=file_extra_config, + ) else: file_objs = [] @@ -103,6 +110,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): ) # get tracing instance + user_id = user.id if isinstance(user, Account) else user.session_id trace_manager = TraceQueueManager(app_model.id) # init application generate entity @@ -110,7 +118,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), - inputs=self._get_cleaned_inputs(inputs, app_config), + inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), query=query, files=file_objs, user_id=user.id, @@ -251,10 +259,16 @@ class CompletionAppGenerator(MessageBasedAppGenerator): override_model_config_dict["model"] = model_dict # parse files - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER + file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user) + file_objs = file_factory.build_from_mappings( + mappings=message.message_files, + tenant_id=app_model.tenant_id, + user_id=user.id, + role=role, + config=file_extra_config, + ) else: file_objs = [] diff --git 
a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 65b759acf5..2b5597e055 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -26,7 +26,7 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db -from models.account import Account +from models import Account from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError @@ -235,13 +235,13 @@ class MessageBasedAppGenerator(BaseAppGenerator): for file in application_generate_entity.files: message_file = MessageFile( message_id=message.id, - type=file.type.value, - transfer_method=file.transfer_method.value, + type=file.type, + transfer_method=file.transfer_method, belongs_to="user", - url=file.url, + url=file.remote_url, upload_file_id=file.related_id, created_by_role=("account" if account_id else "end_user"), - created_by=account_id or end_user_id, + created_by=account_id or end_user_id or "", ) db.session.add(message_file) db.session.commit() diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index bd0ab53278..a865c8a68b 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -3,7 +3,7 @@ import logging import os import threading import uuid -from collections.abc import Generator +from collections.abc import Generator, Mapping, Sequence from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app @@ -20,13 +20,12 @@ from 
core.app.apps.workflow.generate_response_converter import WorkflowAppGenera from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse -from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db -from models.account import Account -from models.model import App, EndUser -from models.workflow import Workflow +from factories import file_factory +from models import Account, App, EndUser, Workflow +from models.enums import CreatedByRole logger = logging.getLogger(__name__) @@ -63,49 +62,46 @@ class WorkflowAppGenerator(BaseAppGenerator): app_model: App, workflow: Workflow, user: Union[Account, EndUser], - args: dict, + args: Mapping[str, Any], invoke_from: InvokeFrom, stream: bool = True, call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, ): - """ - Generate App response. 
+ files: Sequence[Mapping[str, Any]] = args.get("files") or [] - :param app_model: App - :param workflow: Workflow - :param user: account or end user - :param args: request args - :param invoke_from: invoke from source - :param stream: is stream - :param call_depth: call depth - :param workflow_thread_pool_id: workflow thread pool id - """ - inputs = args["inputs"] + role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER # parse files - files = args["files"] if args.get("files") else [] - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) - else: - file_objs = [] + system_files = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + user_id=user.id, + role=role, + config=file_extra_config, + ) # convert to app config - app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + app_config = WorkflowAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow, + ) # get tracing instance - user_id = user.id if isinstance(user, Account) else user.session_id - trace_manager = TraceQueueManager(app_model.id, user_id) + trace_manager = TraceQueueManager( + app_id=app_model.id, + user_id=user.id if isinstance(user, Account) else user.session_id, + ) + inputs: Mapping[str, Any] = args["inputs"] workflow_run_id = str(uuid.uuid4()) # init application generate entity application_generate_entity = WorkflowAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, - inputs=self._get_cleaned_inputs(inputs, app_config), - files=file_objs, + inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + files=system_files, user_id=user.id, stream=stream, 
invoke_from=invoke_from, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 378a4bb8bc..faefcb0ed5 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -1,21 +1,20 @@ import logging -import os from typing import Optional, cast +from configs import dify_config from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback from core.app.entities.app_invoke_entities import ( InvokeFrom, WorkflowAppGenerateEntity, ) -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import UserFrom +from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from models.enums import UserFrom from models.model import App, EndUser from models.workflow import WorkflowType @@ -71,7 +70,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): db.session.close() workflow_callbacks: list[WorkflowCallback] = [] - if bool(os.environ.get("DEBUG", "False").lower() == "true"): + if dify_config.DEBUG: workflow_callbacks.append(WorkflowLoggingCallback()) # if only single iteration run is requested diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 7c53556e43..419a5da806 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -1,4 +1,3 @@ -import json import logging import time from collections.abc import Generator @@ -334,9 +333,7 @@ class 
WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - outputs=json.dumps(event.outputs) - if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs - else None, + outputs=event.outputs, conversation_id=None, trace_manager=trace_manager, ) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index ce266116a7..ca23bbdd47 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -20,7 +20,6 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from core.workflow.entities.node_entities import NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, @@ -41,9 +40,9 @@ from core.workflow.graph_engine.entities.event import ( ParallelBranchRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.iteration.entities import IterationNodeData -from core.workflow.nodes.node_mapping import node_classes +from core.workflow.nodes import NodeType +from core.workflow.nodes.iteration import IterationNodeData +from core.workflow.nodes.node_mapping import node_type_classes_mapping from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App @@ -137,9 +136,8 @@ class WorkflowBasedAppRunner(AppRunner): raise ValueError("iteration node id not found in workflow graph") # Get node class - node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type")) - node_cls = node_classes.get(node_type) - node_cls = cast(type[BaseNode], node_cls) + node_type = NodeType(iteration_node_config.get("data", {}).get("type")) + node_cls = 
node_type_classes_mapping[node_type] # init variable pool variable_pool = VariablePool( diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 98685513a3..f2eba29323 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from enum import Enum from typing import Any, Optional @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat from constants import UUID_NIL from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from core.file.file_obj import FileVar +from core.file.models import File from core.model_runtime.entities.model_entities import AIModelEntity from core.ops.ops_trace_manager import TraceQueueManager @@ -23,7 +23,7 @@ class InvokeFrom(Enum): DEBUGGER = "debugger" @classmethod - def value_of(cls, value: str) -> "InvokeFrom": + def value_of(cls, value: str): """ Get value of given mode. 
@@ -82,7 +82,7 @@ class AppGenerateEntity(BaseModel): app_config: AppConfig inputs: Mapping[str, Any] - files: list[FileVar] = [] + files: Sequence[File] user_id: str # extras diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 4577e28535..bc43baf8a5 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -5,9 +5,10 @@ from typing import Any, Optional from pydantic import BaseModel, field_validator from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes import NodeType +from core.workflow.nodes.base import BaseNodeData class QueueEvent(str, Enum): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 49e5f55ebc..4b5f4716ed 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping, Sequence from enum import Enum from typing import Any, Optional @@ -119,6 +120,7 @@ class MessageEndStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.MESSAGE_END id: str metadata: dict = {} + files: Optional[Sequence[Mapping[str, Any]]] = None class MessageFileStreamResponse(StreamResponse): @@ -211,7 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse): created_by: Optional[dict] = None created_at: int finished_at: int - files: Optional[list[dict]] = [] + files: Optional[Sequence[Mapping[str, Any]]] = [] event: StreamEvent = StreamEvent.WORKFLOW_FINISHED workflow_run_id: str @@ -296,7 +298,7 @@ class NodeFinishStreamResponse(StreamResponse): execution_metadata: Optional[dict] = None 
created_at: int finished_at: int - files: Optional[list[dict]] = [] + files: Optional[Sequence[Mapping[str, Any]]] = [] parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None parent_parallel_id: Optional[str] = None diff --git a/api/core/app/segments/parser.py b/api/core/app/segments/parser.py deleted file mode 100644 index 3c4d7046f4..0000000000 --- a/api/core/app/segments/parser.py +++ /dev/null @@ -1,18 +0,0 @@ -import re - -from core.workflow.entities.variable_pool import VariablePool - -from . import SegmentGroup, factory - -VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") - - -def convert_template(*, template: str, variable_pool: VariablePool): - parts = re.split(VARIABLE_PATTERN, template) - segments = [] - for part in filter(lambda x: x, parts): - if "." in part and (value := variable_pool.get(part.split("."))): - segments.append(value) - else: - segments.append(factory.build_segment(part)) - return SegmentGroup(value=segments) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index b8f5ac2603..138503d404 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -1,5 +1,6 @@ import json import time +from collections.abc import Mapping, Sequence from datetime import datetime, timezone from typing import Any, Optional, Union, cast @@ -27,27 +28,26 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, WorkflowTaskState, ) -from core.file.file_obj import FileVar +from core.file import FILE_MODEL_IDENTITY, File from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.tool_manager import ToolManager -from core.workflow.entities.node_entities import NodeType from core.workflow.enums 
import SystemVariableKey +from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.account import Account +from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.model import EndUser from models.workflow import ( - CreatedByRole, Workflow, WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom, WorkflowRun, WorkflowRunStatus, - WorkflowRunTriggeredFrom, ) @@ -117,7 +117,7 @@ class WorkflowCycleManage: start_at: float, total_tokens: int, total_steps: int, - outputs: Optional[str] = None, + outputs: Mapping[str, Any] | None = None, conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: @@ -133,8 +133,10 @@ class WorkflowCycleManage: """ workflow_run = self._refetch_workflow_run(workflow_run.id) + outputs = WorkflowEntry.handle_special_values(outputs) + workflow_run.status = WorkflowRunStatus.SUCCEEDED.value - workflow_run.outputs = outputs + workflow_run.outputs = json.dumps(outputs or {}) workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps @@ -265,6 +267,7 @@ class WorkflowCycleManage: workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) inputs = WorkflowEntry.handle_special_values(event.inputs) + process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) execution_metadata = ( json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None @@ -276,7 +279,7 @@ class WorkflowCycleManage: { WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value, WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: 
json.dumps(event.process_data) if event.process_data else None, + WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None, WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, WorkflowNodeExecution.execution_metadata: execution_metadata, WorkflowNodeExecution.finished_at: finished_at, @@ -286,10 +289,11 @@ class WorkflowCycleManage: db.session.commit() db.session.close() + process_data = WorkflowEntry.handle_special_values(event.process_data) workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None + workflow_node_execution.process_data = json.dumps(process_data) if process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.finished_at = finished_at @@ -308,6 +312,7 @@ class WorkflowCycleManage: workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) inputs = WorkflowEntry.handle_special_values(event.inputs) + process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) finished_at = datetime.now(timezone.utc).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() @@ -317,7 +322,7 @@ class WorkflowCycleManage: WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, WorkflowNodeExecution.error: event.error, WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None, + WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None, WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, 
WorkflowNodeExecution.finished_at: finished_at, WorkflowNodeExecution.elapsed_time: elapsed_time, @@ -326,11 +331,12 @@ class WorkflowCycleManage: db.session.commit() db.session.close() + process_data = WorkflowEntry.handle_special_values(event.process_data) workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = event.error workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None + workflow_node_execution.process_data = json.dumps(process_data) if process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None workflow_node_execution.finished_at = finished_at workflow_node_execution.elapsed_time = elapsed_time @@ -637,7 +643,7 @@ class WorkflowCycleManage: ), ) - def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: + def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]: """ Fetch files from node outputs :param outputs_dict: node outputs dict @@ -646,15 +652,15 @@ class WorkflowCycleManage: if not outputs_dict: return [] - files = [] - for output_var, output_value in outputs_dict.items(): - file_vars = self._fetch_files_from_variable_value(output_value) - if file_vars: - files.extend(file_vars) + files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] + # Remove None + files = [file for file in files if file] + # Flatten list + files = [file for sublist in files for file in sublist] return files - def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]: + def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]: """ Fetch files from variable value :param value: variable value @@ -666,17 +672,17 @@ class WorkflowCycleManage: files = [] if isinstance(value, list): for item in value: - 
file_var = self._get_file_var_from_value(item) - if file_var: - files.append(file_var) + file = self._get_file_var_from_value(item) + if file: + files.append(file) elif isinstance(value, dict): - file_var = self._get_file_var_from_value(value) - if file_var: - files.append(file_var) + file = self._get_file_var_from_value(value) + if file: + files.append(file) return files - def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]: + def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None: """ Get file var from value :param value: variable value @@ -685,14 +691,11 @@ class WorkflowCycleManage: if not value: return None - if isinstance(value, dict): - if "__variant" in value and value["__variant"] == FileVar.__name__: - return value - elif isinstance(value, FileVar): + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + return value + elif isinstance(value, File): return value.to_dict() - return None - def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: """ Refetch workflow run diff --git a/api/core/entities/message_entities.py b/api/core/entities/message_entities.py deleted file mode 100644 index 10bc9f6ed7..0000000000 --- a/api/core/entities/message_entities.py +++ /dev/null @@ -1,29 +0,0 @@ -import enum -from typing import Any - -from pydantic import BaseModel - - -class PromptMessageFileType(enum.Enum): - IMAGE = "image" - - @staticmethod - def value_of(value): - for member in PromptMessageFileType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class PromptMessageFile(BaseModel): - type: PromptMessageFileType - data: Any = None - - -class ImagePromptMessageFile(PromptMessageFile): - class DETAIL(enum.Enum): - LOW = "low" - HIGH = "high" - - type: PromptMessageFileType = PromptMessageFileType.IMAGE - detail: DETAIL = DETAIL.LOW diff --git a/api/core/file/__init__.py 
b/api/core/file/__init__.py index e69de29bb2..bdaf8793fa 100644 --- a/api/core/file/__init__.py +++ b/api/core/file/__init__.py @@ -0,0 +1,19 @@ +from .constants import FILE_MODEL_IDENTITY +from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType +from .models import ( + File, + FileExtraConfig, + ImageConfig, +) + +__all__ = [ + "FileType", + "FileExtraConfig", + "FileTransferMethod", + "FileBelongsTo", + "File", + "ImageConfig", + "FileAttribute", + "ArrayFileAttribute", + "FILE_MODEL_IDENTITY", +] diff --git a/api/core/file/constants.py b/api/core/file/constants.py new file mode 100644 index 0000000000..ce1d238e93 --- /dev/null +++ b/api/core/file/constants.py @@ -0,0 +1 @@ +FILE_MODEL_IDENTITY = "__dify__file__" diff --git a/api/core/file/enums.py b/api/core/file/enums.py new file mode 100644 index 0000000000..f4153f1676 --- /dev/null +++ b/api/core/file/enums.py @@ -0,0 +1,55 @@ +from enum import Enum + + +class FileType(str, Enum): + IMAGE = "image" + DOCUMENT = "document" + AUDIO = "audio" + VIDEO = "video" + CUSTOM = "custom" + + @staticmethod + def value_of(value): + for member in FileType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileTransferMethod(str, Enum): + REMOTE_URL = "remote_url" + LOCAL_FILE = "local_file" + TOOL_FILE = "tool_file" + + @staticmethod + def value_of(value): + for member in FileTransferMethod: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileBelongsTo(str, Enum): + USER = "user" + ASSISTANT = "assistant" + + @staticmethod + def value_of(value): + for member in FileBelongsTo: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileAttribute(str, Enum): + TYPE = "type" + SIZE = "size" + NAME = "name" + MIME_TYPE = "mime_type" + TRANSFER_METHOD = "transfer_method" + URL = 
"url" + EXTENSION = "extension" + + +class ArrayFileAttribute(str, Enum): + LENGTH = "length" diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py new file mode 100644 index 0000000000..0c6ce8ce75 --- /dev/null +++ b/api/core/file/file_manager.py @@ -0,0 +1,156 @@ +import base64 + +from configs import dify_config +from core.file import file_repository +from core.helper import ssrf_proxy +from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent +from extensions.ext_database import db +from extensions.ext_storage import storage + +from . import helpers +from .enums import FileAttribute +from .models import File, FileTransferMethod, FileType +from .tool_file_parser import ToolFileParser + + +def get_attr(*, file: File, attr: FileAttribute): + match attr: + case FileAttribute.TYPE: + return file.type.value + case FileAttribute.SIZE: + return file.size + case FileAttribute.NAME: + return file.filename + case FileAttribute.MIME_TYPE: + return file.mime_type + case FileAttribute.TRANSFER_METHOD: + return file.transfer_method.value + case FileAttribute.URL: + return file.remote_url + case FileAttribute.EXTENSION: + return file.extension + case _: + raise ValueError(f"Invalid file attribute: {attr}") + + +def to_prompt_message_content(f: File, /): + """ + Convert a File object to an ImagePromptMessageContent object. + + This function takes a File object and converts it to an ImagePromptMessageContent + object, which can be used as a prompt for image-based AI models. + + Args: + file (File): The File object to convert. Must be of type FileType.IMAGE. + + Returns: + ImagePromptMessageContent: An object containing the image data and detail level. + + Raises: + ValueError: If the file is not an image or if the file data is missing. + + Note: + The detail level of the image prompt is determined by the file's extra_config. + If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW. 
+ """ + match f.type: + case FileType.IMAGE: + if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": + data = _to_url(f) + else: + data = _to_base64_data_string(f) + + if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail: + detail = f._extra_config.image_config.detail + else: + detail = ImagePromptMessageContent.DETAIL.LOW + + return ImagePromptMessageContent(data=data, detail=detail) + case FileType.AUDIO: + encoded_string = _file_to_encoded_string(f) + if f.extension is None: + raise ValueError("Missing file extension") + return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) + case _: + raise ValueError(f"file type {f.type} is not supported") + + +def download(f: File, /): + upload_file = file_repository.get_upload_file(session=db.session(), file=f) + return _download_file_content(upload_file.key) + + +def _download_file_content(path: str, /): + """ + Download and return the contents of a file as bytes. + + This function loads the file from storage and ensures it's in bytes format. + + Args: + path (str): The path to the file in storage. + + Returns: + bytes: The contents of the file as a bytes object. + + Raises: + ValueError: If the loaded file is not a bytes object. 
+ """ + data = storage.load(path, stream=False) + if not isinstance(data, bytes): + raise ValueError(f"file {path} is not a bytes object") + return data + + +def _get_encoded_string(f: File, /): + match f.transfer_method: + case FileTransferMethod.REMOTE_URL: + response = ssrf_proxy.get(f.remote_url) + response.raise_for_status() + content = response.content + encoded_string = base64.b64encode(content).decode("utf-8") + return encoded_string + case FileTransferMethod.LOCAL_FILE: + upload_file = file_repository.get_upload_file(session=db.session(), file=f) + data = _download_file_content(upload_file.key) + encoded_string = base64.b64encode(data).decode("utf-8") + return encoded_string + case FileTransferMethod.TOOL_FILE: + tool_file = file_repository.get_tool_file(session=db.session(), file=f) + data = _download_file_content(tool_file.file_key) + encoded_string = base64.b64encode(data).decode("utf-8") + return encoded_string + case _: + raise ValueError(f"Unsupported transfer method: {f.transfer_method}") + + +def _to_base64_data_string(f: File, /): + encoded_string = _get_encoded_string(f) + return f"data:{f.mime_type};base64,{encoded_string}" + + +def _file_to_encoded_string(f: File, /): + match f.type: + case FileType.IMAGE: + return _to_base64_data_string(f) + case FileType.AUDIO: + return _get_encoded_string(f) + case _: + raise ValueError(f"file type {f.type} is not supported") + + +def _to_url(f: File, /): + if f.transfer_method == FileTransferMethod.REMOTE_URL: + if f.remote_url is None: + raise ValueError("Missing file remote_url") + return f.remote_url + elif f.transfer_method == FileTransferMethod.LOCAL_FILE: + if f.related_id is None: + raise ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=f.related_id) + elif f.transfer_method == FileTransferMethod.TOOL_FILE: + # add sign url + if f.related_id is None or f.extension is None: + raise ValueError("Missing file related_id or extension") + return 
ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension) + else: + raise ValueError(f"Unsupported transfer method: {f.transfer_method}") diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py deleted file mode 100644 index 5c4e694025..0000000000 --- a/api/core/file/file_obj.py +++ /dev/null @@ -1,145 +0,0 @@ -import enum -from typing import Any, Optional - -from pydantic import BaseModel - -from core.file.tool_file_parser import ToolFileParser -from core.file.upload_file_parser import UploadFileParser -from core.model_runtime.entities.message_entities import ImagePromptMessageContent -from extensions.ext_database import db - - -class FileExtraConfig(BaseModel): - """ - File Upload Entity. - """ - - image_config: Optional[dict[str, Any]] = None - - -class FileType(enum.Enum): - IMAGE = "image" - - @staticmethod - def value_of(value): - for member in FileType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileTransferMethod(enum.Enum): - REMOTE_URL = "remote_url" - LOCAL_FILE = "local_file" - TOOL_FILE = "tool_file" - - @staticmethod - def value_of(value): - for member in FileTransferMethod: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileBelongsTo(enum.Enum): - USER = "user" - ASSISTANT = "assistant" - - @staticmethod - def value_of(value): - for member in FileBelongsTo: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileVar(BaseModel): - id: Optional[str] = None # message file id - tenant_id: str - type: FileType - transfer_method: FileTransferMethod - url: Optional[str] = None # remote url - related_id: Optional[str] = None - extra_config: Optional[FileExtraConfig] = None - filename: Optional[str] = None - extension: Optional[str] = None - mime_type: Optional[str] = None - - def 
to_dict(self) -> dict: - return { - "__variant": self.__class__.__name__, - "tenant_id": self.tenant_id, - "type": self.type.value, - "transfer_method": self.transfer_method.value, - "url": self.preview_url, - "remote_url": self.url, - "related_id": self.related_id, - "filename": self.filename, - "extension": self.extension, - "mime_type": self.mime_type, - } - - def to_markdown(self) -> str: - """ - Convert file to markdown - :return: - """ - preview_url = self.preview_url - if self.type == FileType.IMAGE: - text = f'![{self.filename or ""}]({preview_url})' - else: - text = f"[{self.filename or preview_url}]({preview_url})" - - return text - - @property - def data(self) -> Optional[str]: - """ - Get image data, file signed url or base64 data - depending on config MULTIMODAL_SEND_IMAGE_FORMAT - :return: - """ - return self._get_data() - - @property - def preview_url(self) -> Optional[str]: - """ - Get signed preview url - :return: - """ - return self._get_data(force_url=True) - - @property - def prompt_message_content(self) -> ImagePromptMessageContent: - if self.type == FileType.IMAGE: - image_config = self.extra_config.image_config - - return ImagePromptMessageContent( - data=self.data, - detail=ImagePromptMessageContent.DETAIL.HIGH - if image_config.get("detail") == "high" - else ImagePromptMessageContent.DETAIL.LOW, - ) - - def _get_data(self, force_url: bool = False) -> Optional[str]: - from models.model import UploadFile - - if self.type == FileType.IMAGE: - if self.transfer_method == FileTransferMethod.REMOTE_URL: - return self.url - elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - upload_file = ( - db.session.query(UploadFile) - .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id) - .first() - ) - - return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url) - elif self.transfer_method == FileTransferMethod.TOOL_FILE: - extension = self.extension - # add sign url - return 
def _verify_signature(*, action: str, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    """Recompute the HMAC-SHA256 over ``action|id|timestamp|nonce`` and check signature and age."""
    data_to_sign = f"{action}|{upload_file_id}|{timestamp}|{nonce}"
    secret_key = dify_config.SECRET_KEY.encode()
    recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
    recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()

    # Constant-time comparison so response timing does not reveal how much of
    # the signature matched. Compare bytes: compare_digest rejects non-ASCII str.
    if not hmac.compare_digest(sign.encode(), recalculated_encoded_sign.encode()):
        return False

    # The timestamp is covered by the signature, so it is only parsed once the
    # signature is known to be valid.
    current_time = int(time.time())
    return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT


def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    """
    Verify a signature built over ``image-preview|<id>|<timestamp>|<nonce>``.

    :param upload_file_id: id of the upload file the signature was issued for
    :param timestamp: Unix timestamp (seconds, as a string) used when signing
    :param nonce: nonce used when signing
    :param sign: URL-safe base64 signature to check
    :return: True if the signature matches and is not older than
        FILES_ACCESS_TIMEOUT seconds, False otherwise
    """
    return _verify_signature(
        action="image-preview", upload_file_id=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign
    )


def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    """
    Verify a signature built over ``file-preview|<id>|<timestamp>|<nonce>``.

    :param upload_file_id: id of the upload file the signature was issued for
    :param timestamp: Unix timestamp (seconds, as a string) used when signing
    :param nonce: nonce used when signing
    :param sign: URL-safe base64 signature to check
    :return: True if the signature matches and is not older than
        FILES_ACCESS_TIMEOUT seconds, False otherwise
    """
    return _verify_signature(
        action="file-preview", upload_file_id=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign
    )
validate_and_transform_files_arg( - self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser] - ) -> list[FileVar]: - """ - validate and transform files arg - - :param files: - :param file_extra_config: - :param user: - :return: - """ - for file in files: - if not isinstance(file, dict): - raise ValueError("Invalid file format, must be dict") - if not file.get("type"): - raise ValueError("Missing file type") - FileType.value_of(file.get("type")) - if not file.get("transfer_method"): - raise ValueError("Missing file transfer method") - FileTransferMethod.value_of(file.get("transfer_method")) - if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value: - if not file.get("url"): - raise ValueError("Missing file url") - if not file.get("url").startswith("http"): - raise ValueError("Invalid file url") - if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"): - raise ValueError("Missing file upload_file_id") - if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"): - raise ValueError("Missing file tool_file_id") - - # transform files to file objs - type_file_objs = self._to_file_objs(files, file_extra_config) - - # validate files - new_files = [] - for file_type, file_objs in type_file_objs.items(): - if file_type == FileType.IMAGE: - # parse and validate files - image_config = file_extra_config.image_config - - # check if image file feature is enabled - if not image_config: - continue - - # Validate number of files - if len(files) > image_config["number_limits"]: - raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}") - - for file_obj in file_objs: - # Validate transfer method - if file_obj.transfer_method.value not in image_config["transfer_methods"]: - raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}") - - # Validate file type - if 
file_obj.type != FileType.IMAGE: - raise ValueError(f"Invalid file type: {file_obj.type}") - - if file_obj.transfer_method == FileTransferMethod.REMOTE_URL: - # check remote url valid and is image - result, error = self._check_image_remote_url(file_obj.url) - if result is False: - raise ValueError(error) - elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE: - # get upload file from upload_file_id - upload_file = ( - db.session.query(UploadFile) - .filter( - UploadFile.id == file_obj.related_id, - UploadFile.tenant_id == self.tenant_id, - UploadFile.created_by == user.id, - UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"), - UploadFile.extension.in_(IMAGE_EXTENSIONS), - ) - .first() - ) - - # check upload file is belong to tenant and user - if not upload_file: - raise ValueError("Invalid upload file") - - new_files.append(file_obj) - - # return all file objs - return new_files - - def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig): - """ - transform message files - - :param files: - :param file_extra_config: - :return: - """ - # transform files to file objs - type_file_objs = self._to_file_objs(files, file_extra_config) - - # return all file objs - return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] - - def _to_file_objs( - self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig - ) -> dict[FileType, list[FileVar]]: - """ - transform files to file objs - - :param files: - :param file_extra_config: - :return: - """ - type_file_objs: dict[FileType, list[FileVar]] = { - # Currently only support image - FileType.IMAGE: [] - } - - if not files: - return type_file_objs - - # group by file type and convert file args or message files to FileObj - for file in files: - if isinstance(file, MessageFile): - if file.belongs_to == FileBelongsTo.ASSISTANT.value: - continue - - file_obj = self._to_file_obj(file, file_extra_config) - 
if file_obj.type not in type_file_objs: - continue - - type_file_objs[file_obj.type].append(file_obj) - - return type_file_objs - - def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig): - """ - transform file to file obj - - :param file: - :return: - """ - if isinstance(file, dict): - transfer_method = FileTransferMethod.value_of(file.get("transfer_method")) - if transfer_method != FileTransferMethod.TOOL_FILE: - return FileVar( - tenant_id=self.tenant_id, - type=FileType.value_of(file.get("type")), - transfer_method=transfer_method, - url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=file_extra_config, - ) - return FileVar( - tenant_id=self.tenant_id, - type=FileType.value_of(file.get("type")), - transfer_method=transfer_method, - url=None, - related_id=file.get("tool_file_id"), - extra_config=file_extra_config, - ) - else: - return FileVar( - id=file.id, - tenant_id=self.tenant_id, - type=FileType.value_of(file.type), - transfer_method=FileTransferMethod.value_of(file.transfer_method), - url=file.url, - related_id=file.upload_file_id or None, - extra_config=file_extra_config, - ) - - def _check_image_remote_url(self, url): - try: - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" - " Chrome/91.0.4472.124 Safari/537.36" - } - - def is_s3_presigned_url(url): - try: - parsed_url = urlparse(url) - if "amazonaws.com" not in parsed_url.netloc: - return False - query_params = parse_qs(parsed_url.query) - - def check_presign_v2(query_params): - required_params = ["Signature", "Expires"] - for param in required_params: - if param not in query_params: - return False - if not query_params["Expires"][0].isdigit(): - return False - signature = query_params["Signature"][0] - if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature): - 
return False - - return True - - def check_presign_v4(query_params): - required_params = ["X-Amz-Signature", "X-Amz-Expires"] - for param in required_params: - if param not in query_params: - return False - if not query_params["X-Amz-Expires"][0].isdigit(): - return False - signature = query_params["X-Amz-Signature"][0] - if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature): - return False - - return True - - return check_presign_v4(query_params) or check_presign_v2(query_params) - except Exception: - return False - - if is_s3_presigned_url(url): - response = requests.get(url, headers=headers, allow_redirects=True) - if response.status_code in {200, 304}: - return True, "" - - response = requests.head(url, headers=headers, allow_redirects=True) - if response.status_code in {200, 304}: - return True, "" - else: - return False, "URL does not exist." - except requests.RequestException as e: - return False, f"Error checking URL: {e}" diff --git a/api/core/file/models.py b/api/core/file/models.py new file mode 100644 index 0000000000..866ff3155b --- /dev/null +++ b/api/core/file/models.py @@ -0,0 +1,140 @@ +from collections.abc import Mapping, Sequence +from typing import Optional + +from pydantic import BaseModel, Field, model_validator + +from core.model_runtime.entities.message_entities import ImagePromptMessageContent + +from . import helpers +from .constants import FILE_MODEL_IDENTITY +from .enums import FileTransferMethod, FileType +from .tool_file_parser import ToolFileParser + + +class ImageConfig(BaseModel): + """ + NOTE: This part of validation is deprecated, but still used in app features "Image Upload". + """ + + number_limits: int = 0 + transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + detail: ImagePromptMessageContent.DETAIL | None = None + + +class FileExtraConfig(BaseModel): + """ + File Upload Entity. 
+ """ + + image_config: Optional[ImageConfig] = None + allowed_file_types: Sequence[FileType] = Field(default_factory=list) + allowed_extensions: Sequence[str] = Field(default_factory=list) + allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + number_limits: int = 0 + + +class File(BaseModel): + dify_model_identity: str = FILE_MODEL_IDENTITY + + id: Optional[str] = None # message file id + tenant_id: str + type: FileType + transfer_method: FileTransferMethod + remote_url: Optional[str] = None # remote url + related_id: Optional[str] = None + filename: Optional[str] = None + extension: Optional[str] = Field(default=None, description="File extension, should contains dot") + mime_type: Optional[str] = None + size: int = -1 + _extra_config: FileExtraConfig | None = None + + def to_dict(self) -> Mapping[str, str | int | None]: + data = self.model_dump(mode="json") + return { + **data, + "url": self.generate_url(), + } + + @property + def markdown(self) -> str: + url = self.generate_url() + if self.type == FileType.IMAGE: + text = f'![{self.filename or ""}]({url})' + else: + text = f"[{self.filename or url}]({url})" + + return text + + def generate_url(self) -> Optional[str]: + if self.type == FileType.IMAGE: + if self.transfer_method == FileTransferMethod.REMOTE_URL: + return self.remote_url + elif self.transfer_method == FileTransferMethod.LOCAL_FILE: + if self.related_id is None: + raise ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=self.related_id) + elif self.transfer_method == FileTransferMethod.TOOL_FILE: + assert self.related_id is not None + assert self.extension is not None + return ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=self.extension + ) + else: + if self.transfer_method == FileTransferMethod.REMOTE_URL: + return self.remote_url + elif self.transfer_method == FileTransferMethod.LOCAL_FILE: + if self.related_id is None: + raise 
ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=self.related_id) + elif self.transfer_method == FileTransferMethod.TOOL_FILE: + assert self.related_id is not None + assert self.extension is not None + return ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=self.extension + ) + + @model_validator(mode="after") + def validate_after(self): + match self.transfer_method: + case FileTransferMethod.REMOTE_URL: + if not self.remote_url: + raise ValueError("Missing file url") + if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"): + raise ValueError("Invalid file url") + case FileTransferMethod.LOCAL_FILE: + if not self.related_id: + raise ValueError("Missing file related_id") + case FileTransferMethod.TOOL_FILE: + if not self.related_id: + raise ValueError("Missing file related_id") + + # Validate the extra config. + if not self._extra_config: + return self + + if self._extra_config.allowed_file_types: + if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM: + raise ValueError(f"Invalid file type: {self.type}") + + if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions: + raise ValueError(f"Invalid file extension: {self.extension}") + + if ( + self._extra_config.allowed_upload_methods + and self.transfer_method not in self._extra_config.allowed_upload_methods + ): + raise ValueError(f"Invalid transfer method: {self.transfer_method}") + + match self.type: + case FileType.IMAGE: + # NOTE: This part of validation is deprecated, but still used in app features "Image Upload". 
+ if not self._extra_config.image_config: + return self + # TODO: skip check if transfer_methods is empty, because many test cases are not setting this field + if ( + self._extra_config.image_config.transfer_methods + and self.transfer_method not in self._extra_config.image_config.transfer_methods + ): + raise ValueError(f"Invalid transfer method: {self.transfer_method}") + + return self diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index 1efaf5529d..a17b7be367 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -1,4 +1,9 @@ -tool_file_manager = {"manager": None} +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from core.tools.tool_file_manager import ToolFileManager + +tool_file_manager: dict[str, Any] = {"manager": None} class ToolFileParser: diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py deleted file mode 100644 index a8c1fd4d02..0000000000 --- a/api/core/file/upload_file_parser.py +++ /dev/null @@ -1,79 +0,0 @@ -import base64 -import hashlib -import hmac -import logging -import os -import time -from typing import Optional - -from configs import dify_config -from extensions.ext_storage import storage - -IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] -IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) - - -class UploadFileParser: - @classmethod - def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: - if not upload_file: - return None - - if upload_file.extension not in IMAGE_EXTENSIONS: - return None - - if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url: - return cls.get_signed_temp_image_url(upload_file.id) - else: - # get image file base64 - try: - data = storage.load(upload_file.key) - except FileNotFoundError: - logging.error(f"File not found: {upload_file.key}") - return None - - encoded_string = base64.b64encode(data).decode("utf-8") - return 
f"data:{upload_file.mime_type};base64,{encoded_string}" - - @classmethod - def get_signed_temp_image_url(cls, upload_file_id) -> str: - """ - get signed url from upload file - - :param upload_file: UploadFile object - :return: - """ - base_url = dify_config.FILES_URL - image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" - - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - - return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" - - @classmethod - def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - """ - verify signature - - :param upload_file_id: file id - :param timestamp: timestamp - :param nonce: nonce - :param sign: signature - :return: - """ - data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - # verify signature - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 4e6d58904e..6793e41978 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -13,8 +13,11 @@ SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "") SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "") SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3")) -proxies = ( - {"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL} 
+proxy_mounts = ( + { + "http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL), + "https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL), + } if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None ) @@ -33,11 +36,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): while retries <= max_retries: try: if SSRF_PROXY_ALL_URL: - response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs) - elif proxies: - response = httpx.request(method=method, url=url, proxies=proxies, **kwargs) + with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client: + response = client.request(method=method, url=url, **kwargs) + elif proxy_mounts: + with httpx.Client(mounts=proxy_mounts) as client: + response = client.request(method=method, url=url, **kwargs) else: - response = httpx.request(method=method, url=url, **kwargs) + with httpx.Client() as client: + response = client.request(method=method, url=url, **kwargs) if response.status_code not in STATUS_FORCELIST: return response diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index bc94912c1e..189d94e290 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,18 +1,20 @@ from typing import Optional from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.file.message_file_parser import MessageFileParser +from core.file import file_manager from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, + PromptMessageContent, PromptMessageRole, TextPromptMessageContent, UserPromptMessage, ) from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db +from factories import file_factory from models.model import AppMode, Conversation, Message, 
MessageFile from models.workflow import WorkflowRun @@ -65,13 +67,12 @@ class TokenBufferMemory: messages = list(reversed(thread_messages)) - message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id) prompt_messages = [] for message in messages: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: file_extra_config = None - if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) else: if message.workflow_run_id: @@ -84,17 +85,21 @@ class TokenBufferMemory: workflow_run.workflow.features_dict, is_vision=False ) - if file_extra_config: - file_objs = message_file_parser.transform_message_files(files, file_extra_config) + if file_extra_config and app_record: + file_objs = file_factory.build_from_message_files( + message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config + ) else: file_objs = [] if not file_objs: prompt_messages.append(UserPromptMessage(content=message.query)) else: - prompt_message_contents = [TextPromptMessageContent(data=message.query)] + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=message.query)) for file_obj in file_objs: - prompt_message_contents.append(file_obj.prompt_message_content) + prompt_message = file_manager.to_prompt_message_content(file_obj) + prompt_message_contents.append(prompt_message) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index e394233d2c..e21449ec24 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,7 +1,7 @@ import logging import os -from collections.abc import Callable, Generator, Sequence -from typing import IO, Optional, Union, 
# Names re-exported by this package; each one is listed exactly once
# ("PromptMessage" and "PromptMessageRole" were previously duplicated).
__all__ = [
    "ImagePromptMessageContent",
    "PromptMessage",
    "PromptMessageRole",
    "LLMUsage",
    "ModelPropertyKey",
    "AssistantPromptMessage",
    "PromptMessageContent",
    "SystemPromptMessage",
    "TextPromptMessageContent",
    "UserPromptMessage",
    "PromptMessageTool",
    "ToolPromptMessage",
    "PromptMessageContentType",
    "LLMResult",
    "LLMResultChunk",
    "LLMResultChunkDelta",
    "AudioPromptMessageContent",
]
""" - class DETAIL(Enum): + class DETAIL(str, Enum): LOW = "low" HIGH = "high" diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 0027411a6e..5b6f96129b 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -1,5 +1,4 @@ import logging -import os import re import time from abc import abstractmethod @@ -8,6 +7,7 @@ from typing import Optional, Union from pydantic import ConfigDict +from configs import dify_config from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.logging_callback import LoggingCallback from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -77,7 +77,7 @@ class LargeLanguageModel(AIModel): callbacks = callbacks or [] - if bool(os.environ.get("DEBUG", "False").lower() == "true"): + if dify_config.DEBUG: callbacks.append(LoggingCallback()) # trigger before invoke callbacks @@ -107,7 +107,16 @@ class LargeLanguageModel(AIModel): callbacks=callbacks, ) else: - result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + result = self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) except Exception as e: self._trigger_invoke_error_callbacks( model=model, diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index 862ec29daf..b394ea4e9d 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -1,6 +1,7 @@ import logging import re from abc import abstractmethod +from collections.abc import 
Iterable from typing import Any, Optional from pydantic import ConfigDict @@ -22,8 +23,14 @@ class TTSModel(AIModel): model_config = ConfigDict(protected_namespaces=()) def invoke( - self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None - ): + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + user: Optional[str] = None, + ) -> Iterable[bytes]: """ Invoke large language model @@ -50,8 +57,14 @@ class TTSModel(AIModel): @abstractmethod def _invoke( - self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None - ): + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + user: Optional[str] = None, + ) -> Iterable[bytes]: """ Invoke large language model @@ -68,25 +81,25 @@ class TTSModel(AIModel): def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: """ - Get voice for given tts model voices + Retrieves the list of voices supported by a given text-to-speech (TTS) model. - :param language: tts language - :param model: model name - :param credentials: model credentials - :return: voices lists + :param language: The language for which the voices are requested. + :param model: The name of the TTS model. + :param credentials: The credentials required to access the TTS model. + :return: A list of voices supported by the TTS model. 
""" model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties: - voices = model_schema.model_properties[ModelPropertyKey.VOICES] - if language: - return [ - {"name": d["name"], "value": d["mode"]} - for d in voices - if language and language in d.get("language") - ] - else: - return [{"name": d["name"], "value": d["mode"]} for d in voices] + if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties: + raise ValueError("this model does not support voice") + + voices = model_schema.model_properties[ModelPropertyKey.VOICES] + if language: + return [ + {"name": d["name"], "value": d["mode"]} for d in voices if language and language in d.get("language") + ] + else: + return [{"name": d["name"], "value": d["mode"]} for d in voices] def _get_model_default_voice(self, model: str, credentials: dict) -> Any: """ @@ -111,8 +124,10 @@ class TTSModel(AIModel): """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] + if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties: + raise ValueError("this model does not support audio type") + + return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] def _get_model_word_limit(self, model: str, credentials: dict) -> int: """ @@ -121,8 +136,10 @@ class TTSModel(AIModel): """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] + if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties: + raise ValueError("this model does not support word limit") + + return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] def _get_model_workers_limit(self, model: 
str, credentials: dict) -> int: """ @@ -131,8 +148,10 @@ class TTSModel(AIModel): """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] + if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties: + raise ValueError("this model does not support max workers") + + return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] @staticmethod def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"): diff --git a/api/core/model_runtime/model_providers/openai/llm/_position.yaml b/api/core/model_runtime/model_providers/openai/llm/_position.yaml index 7501bc1164..b7c25ecb16 100644 --- a/api/core/model_runtime/model_providers/openai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/_position.yaml @@ -1,3 +1,4 @@ +- gpt-4o-audio-preview - gpt-4 - gpt-4o - gpt-4o-2024-05-13 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml new file mode 100644 index 0000000000..256e87edbe --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml @@ -0,0 +1,44 @@ +model: gpt-4o-audio-preview +label: + zh_Hans: gpt-4o-audio-preview + en_US: gpt-4o-audio-preview +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 4096 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string 
+ help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '5.00' + output: '15.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 1ac3837ad3..922e5e1314 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Generator -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast import tiktoken from openai import OpenAI, Stream @@ -11,9 +11,9 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho from openai.types.chat.chat_completion_message import FunctionCall from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, + AudioPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, @@ -23,6 +23,7 @@ from core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType, PriceConfig from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -613,6 +614,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # clear illegal prompt messages prompt_messages = self._clear_illegal_prompt_messages(model, 
prompt_messages) + # o1 compatibility block_as_stream = False if model.startswith("o1"): if stream: @@ -626,8 +628,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): del extra_model_kwargs["stop"] # chat model + messages: Any = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] response = client.chat.completions.create( - messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], + messages=messages, model=model, stream=stream, **model_parameters, @@ -946,23 +949,29 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): Convert PromptMessage to dict for OpenAI API """ if isinstance(message, UserPromptMessage): - message = cast(UserPromptMessage, message) if isinstance(message.content, str): message_dict = {"role": "user", "content": message.content} - else: + elif isinstance(message.content, list): sub_messages = [] for message_content in message.content: - if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(TextPromptMessageContent, message_content) + if isinstance(message_content, TextPromptMessageContent): sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) - elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, message_content) + elif isinstance(message_content, ImagePromptMessageContent): sub_message_dict = { "type": "image_url", "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) + elif isinstance(message_content, AudioPromptMessageContent): + sub_message_dict = { + "type": "input_audio", + "input_audio": { + "data": message_content.data, + "format": message_content.format, + }, + } + sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} elif isinstance(message, AssistantPromptMessage): diff --git a/api/core/ops/ops_trace_manager.py 
b/api/core/ops/ops_trace_manager.py index 0200f4a32d..764944f799 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -358,8 +358,8 @@ class TraceTask: workflow_run_id = workflow_run.id workflow_run_elapsed_time = workflow_run.elapsed_time workflow_run_status = workflow_run.status - workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {} - workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {} + workflow_run_inputs = workflow_run.inputs_dict + workflow_run_outputs = workflow_run.outputs_dict workflow_run_version = workflow_run.version error = workflow_run.error or "" diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index ce8038d14e..bbd9531b19 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,12 +1,15 @@ -from typing import Optional, Union +from collections.abc import Sequence +from typing import Optional from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.file.file_obj import FileVar +from core.file import file_manager +from core.file.models import File from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageContent, PromptMessageRole, SystemPromptMessage, TextPromptMessageContent, @@ -14,8 +17,8 @@ from core.model_runtime.entities.message_entities import ( ) from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.prompt_transform import PromptTransform -from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from 
core.workflow.entities.variable_pool import VariablePool class AdvancedPromptTransform(PromptTransform): @@ -28,22 +31,19 @@ class AdvancedPromptTransform(PromptTransform): def get_prompt( self, - prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], - inputs: dict, + *, + prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate, + inputs: dict[str, str], query: str, - files: list[FileVar], + files: Sequence[File], context: Optional[str], memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, - query_prompt_template: Optional[str] = None, ) -> list[PromptMessage]: - inputs = {key: str(value) for key, value in inputs.items()} - prompt_messages = [] - model_mode = ModelMode.value_of(model_config.mode) - if model_mode == ModelMode.COMPLETION: + if isinstance(prompt_template, CompletionModelPromptTemplate): prompt_messages = self._get_completion_model_prompt_messages( prompt_template=prompt_template, inputs=inputs, @@ -54,12 +54,11 @@ class AdvancedPromptTransform(PromptTransform): memory=memory, model_config=model_config, ) - elif model_mode == ModelMode.CHAT: + elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template): prompt_messages = self._get_chat_model_prompt_messages( prompt_template=prompt_template, inputs=inputs, query=query, - query_prompt_template=query_prompt_template, files=files, context=context, memory_config=memory_config, @@ -74,7 +73,7 @@ class AdvancedPromptTransform(PromptTransform): prompt_template: CompletionModelPromptTemplate, inputs: dict, query: Optional[str], - files: list[FileVar], + files: Sequence[File], context: Optional[str], memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], @@ -88,10 +87,10 @@ class AdvancedPromptTransform(PromptTransform): prompt_messages = [] if prompt_template.edition_type == "basic" or not prompt_template.edition_type: - 
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) + prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} - prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) + prompt_inputs = self._set_context_variable(context, parser, prompt_inputs) if memory and memory_config: role_prefix = memory_config.role_prefix @@ -100,15 +99,15 @@ class AdvancedPromptTransform(PromptTransform): memory_config=memory_config, raw_prompt=raw_prompt, role_prefix=role_prefix, - prompt_template=prompt_template, + parser=parser, prompt_inputs=prompt_inputs, model_config=model_config, ) if query: - prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) + prompt_inputs = self._set_query_variable(query, parser, prompt_inputs) - prompt = prompt_template.format(prompt_inputs) + prompt = parser.format(prompt_inputs) else: prompt = raw_prompt prompt_inputs = inputs @@ -116,9 +115,10 @@ class AdvancedPromptTransform(PromptTransform): prompt = Jinja2Formatter.format(prompt, prompt_inputs) if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=prompt)) for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: @@ -131,35 +131,38 @@ class AdvancedPromptTransform(PromptTransform): prompt_template: list[ChatModelMessage], inputs: dict, query: Optional[str], - files: list[FileVar], + files: Sequence[File], context: Optional[str], memory_config: Optional[MemoryConfig], memory: 
Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, - query_prompt_template: Optional[str] = None, ) -> list[PromptMessage]: """ Get chat model prompt messages. """ - raw_prompt_list = prompt_template - prompt_messages = [] - - for prompt_item in raw_prompt_list: + for prompt_item in prompt_template: raw_prompt = prompt_item.text if prompt_item.edition_type == "basic" or not prompt_item.edition_type: - prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - - prompt = prompt_template.format(prompt_inputs) + if self.with_variable_tmpl: + vp = VariablePool() + for k, v in inputs.items(): + if k.startswith("#"): + vp.add(k[1:-1].split("."), v) + raw_prompt = raw_prompt.replace("{{#context#}}", context or "") + prompt = vp.convert_template(raw_prompt).text + else: + parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) + prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + prompt_inputs = self._set_context_variable( + context=context, parser=parser, prompt_inputs=prompt_inputs + ) + prompt = parser.format(prompt_inputs) elif prompt_item.edition_type == "jinja2": prompt = raw_prompt prompt_inputs = inputs - - prompt = Jinja2Formatter.format(prompt, prompt_inputs) + prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs) else: raise ValueError(f"Invalid edition type: {prompt_item.edition_type}") @@ -170,25 +173,25 @@ class AdvancedPromptTransform(PromptTransform): elif prompt_item.role == PromptMessageRole.ASSISTANT: prompt_messages.append(AssistantPromptMessage(content=prompt)) - if query and query_prompt_template: - prompt_template = PromptTemplateParser( - template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl + if query 
and memory_config and memory_config.query_prompt_template: + parser = PromptTemplateParser( + template=memory_config.query_prompt_template, with_variable_tmpl=self.with_variable_tmpl ) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} prompt_inputs["#sys.query#"] = query - prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) + prompt_inputs = self._set_context_variable(context, parser, prompt_inputs) - query = prompt_template.format(prompt_inputs) + query = parser.format(prompt_inputs) if memory and memory_config: prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) - if files: - prompt_message_contents = [TextPromptMessageContent(data=query)] + if files and query is not None: + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=query)) for file in files: - prompt_message_contents.append(file.prompt_message_content) - + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: prompt_messages.append(UserPromptMessage(content=query)) @@ -200,19 +203,19 @@ class AdvancedPromptTransform(PromptTransform): # get last user message content and add files prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) last_message.content = prompt_message_contents else: prompt_message_contents = [TextPromptMessageContent(data="")] # not for query for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) 
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: prompt_message_contents = [TextPromptMessageContent(data=query)] for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) elif query: @@ -220,8 +223,8 @@ class AdvancedPromptTransform(PromptTransform): return prompt_messages - def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: - if "#context#" in prompt_template.variable_keys: + def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: + if "#context#" in parser.variable_keys: if context: prompt_inputs["#context#"] = context else: @@ -229,8 +232,8 @@ class AdvancedPromptTransform(PromptTransform): return prompt_inputs - def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: - if "#query#" in prompt_template.variable_keys: + def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: + if "#query#" in parser.variable_keys: if query: prompt_inputs["#query#"] = query else: @@ -244,16 +247,16 @@ class AdvancedPromptTransform(PromptTransform): memory_config: MemoryConfig, raw_prompt: str, role_prefix: MemoryConfig.RolePrefix, - prompt_template: PromptTemplateParser, + parser: PromptTemplateParser, prompt_inputs: dict, model_config: ModelConfigWithCredentialsEntity, ) -> dict: - if "#histories#" in prompt_template.variable_keys: + if "#histories#" in parser.variable_keys: if memory: inputs = {"#histories#": "", **prompt_inputs} - prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - tmp_human_message = 
UserPromptMessage(content=prompt_template.format(prompt_inputs)) + parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) + prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs)) rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 7479560520..5a3481b963 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -5,9 +5,11 @@ from typing import TYPE_CHECKING, Optional from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.file import file_manager from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( PromptMessage, + PromptMessageContent, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, @@ -18,10 +20,10 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode if TYPE_CHECKING: - from core.file.file_obj import FileVar + from core.file.models import File -class ModelMode(enum.Enum): +class ModelMode(str, enum.Enum): COMPLETION = "completion" CHAT = "chat" @@ -53,7 +55,7 @@ class SimplePromptTransform(PromptTransform): prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: list["FileVar"], + files: list["File"], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, @@ -169,7 +171,7 @@ class SimplePromptTransform(PromptTransform): inputs: dict, query: str, context: Optional[str], - files: list["FileVar"], + files: list["File"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, ) -> tuple[list[PromptMessage], 
Optional[list[str]]]: @@ -214,7 +216,7 @@ class SimplePromptTransform(PromptTransform): inputs: dict, query: str, context: Optional[str], - files: list["FileVar"], + files: list["File"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, ) -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -261,11 +263,12 @@ class SimplePromptTransform(PromptTransform): return [self.get_last_user_message(prompt, files)], stops - def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage: + def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage: if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=prompt)) for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) prompt_message = UserPromptMessage(content=prompt_message_contents) else: diff --git a/api/core/prompt/utils/extract_thread_messages.py b/api/core/prompt/utils/extract_thread_messages.py index e8b626499f..f7aef76c87 100644 --- a/api/core/prompt/utils/extract_thread_messages.py +++ b/api/core/prompt/utils/extract_thread_messages.py @@ -1,7 +1,9 @@ +from typing import Any + from constants import UUID_NIL -def extract_thread_messages(messages: list[dict]) -> list[dict]: +def extract_thread_messages(messages: list[Any]): thread_messages = [] next_message = None diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 29494db221..5eec5e3c99 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,7 +1,8 @@ from typing import cast -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, + 
AudioPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, @@ -21,7 +22,7 @@ class PromptMessageUtil: :return: """ prompts = [] - if model_mode == ModelMode.CHAT.value: + if model_mode == ModelMode.CHAT: tool_calls = [] for prompt_message in prompt_messages: if prompt_message.role == PromptMessageRole.USER: @@ -51,11 +52,9 @@ class PromptMessageUtil: files = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: - if content.type == PromptMessageContentType.TEXT: - content = cast(TextPromptMessageContent, content) + if isinstance(content, TextPromptMessageContent): text += content.data - else: - content = cast(ImagePromptMessageContent, content) + elif isinstance(content, ImagePromptMessageContent): files.append( { "type": "image", @@ -63,6 +62,14 @@ class PromptMessageUtil: "detail": content.detail.value, } ) + elif isinstance(content, AudioPromptMessageContent): + files.append( + { + "type": "audio", + "data": content.data[:10] + "...[TRUNCATED]..." 
+ content.data[-10:], + "format": content.format, + } + ) else: text = prompt_message.content diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 7352ef378b..2b6e048652 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -121,7 +121,7 @@ class WordExtractor(BaseExtractor): db.session.add(upload_file) db.session.commit() image_map[rel.target_part] = ( - f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)" + f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/file-preview)" ) return image_map diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index a0494adc60..68fab0c127 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -9,7 +9,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser -from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.llm import LLMNode PREFIX = """Respond to the human as helpfully and accurately as possible. 
You have access to the following tools:""" diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index b988a588e9..b1db559441 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -32,8 +32,8 @@ class UserToolProvider(BaseModel): original_credentials: Optional[dict] = None is_team_authorization: bool = False allow_delete: bool = True - tools: list[UserTool] = None - labels: list[str] = None + tools: list[UserTool] | None = None + labels: list[str] | None = None def to_dict(self) -> dict: # ------------- @@ -42,7 +42,7 @@ class UserToolProvider(BaseModel): for tool in tools: if tool.get("parameters"): for parameter in tool.get("parameters"): - if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value: + if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: parameter["type"] = "files" # ------------- diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 9962b559fa..9a31e673d3 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -104,14 +104,15 @@ class ToolInvokeMessage(BaseModel): BLOB = "blob" JSON = "json" IMAGE_LINK = "image_link" - FILE_VAR = "file_var" + FILE = "file" type: MessageType = MessageType.TEXT """ plain text, image url or link url """ message: str | bytes | dict | None = None - meta: dict[str, Any] | None = None + # TODO: Use a BaseModel for meta + meta: dict[str, Any] = Field(default_factory=dict) save_as: str = "" @@ -143,6 +144,67 @@ class ToolParameter(BaseModel): SELECT = "select" SECRET_INPUT = "secret-input" FILE = "file" + FILES = "files" + + # deprecated, should not use. 
+ SYSTEM_FILES = "systme-files" + + def as_normal_type(self): + if self in { + ToolParameter.ToolParameterType.SECRET_INPUT, + ToolParameter.ToolParameterType.SELECT, + }: + return "string" + return self.value + + def cast_value(self, value: Any, /): + try: + match self: + case ( + ToolParameter.ToolParameterType.STRING + | ToolParameter.ToolParameterType.SECRET_INPUT + | ToolParameter.ToolParameterType.SELECT + ): + if value is None: + return "" + else: + return value if isinstance(value, str) else str(value) + + case ToolParameter.ToolParameterType.BOOLEAN: + if value is None: + return False + elif isinstance(value, str): + # Allowed YAML boolean value strings: https://yaml.org/type/bool.html + # and also '0' for False and '1' for True + match value.lower(): + case "true" | "yes" | "y" | "1": + return True + case "false" | "no" | "n" | "0": + return False + case _: + return bool(value) + else: + return value if isinstance(value, bool) else bool(value) + + case ToolParameter.ToolParameterType.NUMBER: + if isinstance(value, int | float): + return value + elif isinstance(value, str) and value: + if "." 
in value: + return float(value) + else: + return int(value) + case ( + ToolParameter.ToolParameterType.SYSTEM_FILES + | ToolParameter.ToolParameterType.FILE + | ToolParameter.ToolParameterType.FILES + ): + return value + case _: + return str(value) + + except Exception: + raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.") class ToolParameterForm(Enum): SCHEMA = "schema" # should be set while adding tool diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index a8c647d71e..af9aa6abb4 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool): for image in response.data: mime_type, blob_image = DallE3Tool._decode_image(image.b64_json) blob_message = self.create_blob_message( - blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value + blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE ) result.append(blob_message) return result diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py index 396570248a..3173fb9e13 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py @@ -2,7 +2,7 @@ from typing import Any from duckduckgo_search import DDGS -from core.file.file_obj import FileTransferMethod +from core.file.models import FileTransferMethod from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/podcast_generator/_assets/icon.svg b/api/core/tools/provider/builtin/podcast_generator/_assets/icon.svg new file mode 100644 index 0000000000..01743c9cd3 --- /dev/null +++ 
b/api/core/tools/provider/builtin/podcast_generator/_assets/icon.svg @@ -0,0 +1,24 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/podcast_generator/podcast_generator.py b/api/core/tools/provider/builtin/podcast_generator/podcast_generator.py new file mode 100644 index 0000000000..0b9c025834 --- /dev/null +++ b/api/core/tools/provider/builtin/podcast_generator/podcast_generator.py @@ -0,0 +1,33 @@ +from typing import Any + +import openai + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class PodcastGeneratorProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + tts_service = credentials.get("tts_service") + api_key = credentials.get("api_key") + + if not tts_service: + raise ToolProviderCredentialValidationError("TTS service is not specified") + + if not api_key: + raise ToolProviderCredentialValidationError("API key is missing") + + if tts_service == "openai": + self._validate_openai_credentials(api_key) + else: + raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}") + + def _validate_openai_credentials(self, api_key: str) -> None: + client = openai.OpenAI(api_key=api_key) + try: + # We're using a simple API call to validate the credentials + client.models.list() + except openai.AuthenticationError: + raise ToolProviderCredentialValidationError("Invalid OpenAI API key") + except Exception as e: + raise ToolProviderCredentialValidationError(f"Error validating OpenAI API key: {str(e)}") diff --git a/api/core/tools/provider/builtin/podcast_generator/podcast_generator.yaml b/api/core/tools/provider/builtin/podcast_generator/podcast_generator.yaml new file mode 100644 index 0000000000..bd02b32020 --- /dev/null +++ b/api/core/tools/provider/builtin/podcast_generator/podcast_generator.yaml @@ -0,0 +1,34 @@ +identity: + 
author: Dify + name: podcast_generator + label: + en_US: Podcast Generator + zh_Hans: 播客生成器 + description: + en_US: Generate podcast audio using Text-to-Speech services + zh_Hans: 使用文字转语音服务生成播客音频 + icon: icon.svg +credentials_for_provider: + tts_service: + type: select + required: true + label: + en_US: TTS Service + zh_Hans: TTS 服务 + placeholder: + en_US: Select a TTS service + zh_Hans: 选择一个 TTS 服务 + options: + - label: + en_US: OpenAI TTS + zh_Hans: OpenAI TTS + value: openai + api_key: + type: secret-input + required: true + label: + en_US: API Key + zh_Hans: API 密钥 + placeholder: + en_US: Enter your TTS service API key + zh_Hans: 输入您的 TTS 服务 API 密钥 diff --git a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py new file mode 100644 index 0000000000..8c8dd9bf68 --- /dev/null +++ b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py @@ -0,0 +1,100 @@ +import concurrent.futures +import io +import random +from typing import Any, Literal, Optional, Union + +import openai +from pydub import AudioSegment + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolParameterValidationError, ToolProviderCredentialValidationError +from core.tools.tool.builtin_tool import BuiltinTool + + +class PodcastAudioGeneratorTool(BuiltinTool): + @staticmethod + def _generate_silence(duration: float): + # Generate silent WAV data using pydub + silence = AudioSegment.silent(duration=int(duration * 1000)) # pydub uses milliseconds + return silence + + @staticmethod + def _generate_audio_segment( + client: openai.OpenAI, + line: str, + voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"], + index: int, + ) -> tuple[int, Union[AudioSegment, str], Optional[AudioSegment]]: + try: + response = client.audio.speech.create(model="tts-1", voice=voice, input=line.strip(), response_format="wav") 
+ audio = AudioSegment.from_wav(io.BytesIO(response.content)) + silence_duration = random.uniform(0.1, 1.5) + silence = PodcastAudioGeneratorTool._generate_silence(silence_duration) + return index, audio, silence + except Exception as e: + return index, f"Error generating audio: {str(e)}", None + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + # Extract parameters + script = tool_parameters.get("script", "") + host1_voice = tool_parameters.get("host1_voice") + host2_voice = tool_parameters.get("host2_voice") + + # Split the script into lines + script_lines = [line for line in script.split("\n") if line.strip()] + + # Ensure voices are provided + if not host1_voice or not host2_voice: + raise ToolParameterValidationError("Host voices are required") + + # Get OpenAI API key from credentials + if not self.runtime or not self.runtime.credentials: + raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing") + api_key = self.runtime.credentials.get("api_key") + if not api_key: + raise ToolProviderCredentialValidationError("OpenAI API key is missing") + + # Initialize OpenAI client + client = openai.OpenAI(api_key=api_key) + + # Create a thread pool + max_workers = 5 + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for i, line in enumerate(script_lines): + voice = host1_voice if i % 2 == 0 else host2_voice + future = executor.submit(self._generate_audio_segment, client, line, voice, i) + futures.append(future) + + # Collect results + audio_segments: list[Any] = [None] * len(script_lines) + for future in concurrent.futures.as_completed(futures): + index, audio, silence = future.result() + if isinstance(audio, str): # Error occurred + return self.create_text_message(audio) + audio_segments[index] = (audio, silence) + + # Combine audio segments in the correct order + combined_audio = AudioSegment.empty() + for i, 
(audio, silence) in enumerate(audio_segments): + if audio: + combined_audio += audio + if i < len(audio_segments) - 1 and silence: + combined_audio += silence + + # Export the combined audio to a WAV file in memory + buffer = io.BytesIO() + combined_audio.export(buffer, format="wav") + wav_bytes = buffer.getvalue() + + # Create a blob message with the combined audio + return [ + self.create_text_message("Audio generated successfully"), + self.create_blob_message( + blob=wav_bytes, + meta={"mime_type": "audio/x-wav"}, + save_as=self.VariableKey.AUDIO, + ), + ] diff --git a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.yaml b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.yaml new file mode 100644 index 0000000000..d6ae98f595 --- /dev/null +++ b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.yaml @@ -0,0 +1,95 @@ +identity: + name: podcast_audio_generator + author: Dify + label: + en_US: Podcast Audio Generator + zh_Hans: 播客音频生成器 +description: + human: + en_US: Generate a podcast audio file from a script with two alternating voices using OpenAI's TTS service. + zh_Hans: 使用 OpenAI 的 TTS 服务,从包含两个交替声音的脚本生成播客音频文件。 + llm: This tool converts a prepared podcast script into an audio file using OpenAI's Text-to-Speech service, with two specified voices for alternating hosts. +parameters: + - name: script + type: string + required: true + label: + en_US: Podcast Script + zh_Hans: 播客脚本 + human_description: + en_US: A string containing alternating lines for two hosts, separated by newline characters. + zh_Hans: 包含两位主持人交替台词的字符串,每行用换行符分隔。 + llm_description: A string representing the script, with alternating lines for two hosts separated by newline characters. + form: llm + - name: host1_voice + type: select + required: true + label: + en_US: Host 1 Voice + zh_Hans: 主持人1 音色 + human_description: + en_US: The voice for the first host. 
+ zh_Hans: 第一位主持人的音色。 + llm_description: The voice identifier for the first host's voice. + options: + - label: + en_US: Alloy + zh_Hans: Alloy + value: alloy + - label: + en_US: Echo + zh_Hans: Echo + value: echo + - label: + en_US: Fable + zh_Hans: Fable + value: fable + - label: + en_US: Onyx + zh_Hans: Onyx + value: onyx + - label: + en_US: Nova + zh_Hans: Nova + value: nova + - label: + en_US: Shimmer + zh_Hans: Shimmer + value: shimmer + form: form + - name: host2_voice + type: select + required: true + label: + en_US: Host 2 Voice + zh_Hans: 主持人2 音色 + human_description: + en_US: The voice for the second host. + zh_Hans: 第二位主持人的音色。 + llm_description: The voice identifier for the second host's voice. + options: + - label: + en_US: Alloy + zh_Hans: Alloy + value: alloy + - label: + en_US: Echo + zh_Hans: Echo + value: echo + - label: + en_US: Fable + zh_Hans: Fable + value: fable + - label: + en_US: Onyx + zh_Hans: Onyx + value: onyx + - label: + en_US: Nova + zh_Hans: Nova + value: nova + - label: + en_US: Shimmer + zh_Hans: Shimmer + value: shimmer + form: form diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index ff022812ef..955a0add3b 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -13,7 +13,6 @@ from core.tools.errors import ( from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool -from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.tools.utils.yaml_utils import load_yaml_file @@ -208,9 +207,7 @@ class BuiltinToolProviderController(ToolProviderController): # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - default_value = ToolParameterConverter.cast_parameter_by_type( - parameter_schema.default, parameter_schema.type - ) + 
default_value = parameter_schema.type.cast_value(parameter_schema.default) tool_parameters[parameter] = default_value def validate_credentials(self, credentials: dict[str, Any]) -> None: diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index 321b212014..bc05a11562 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -11,7 +11,6 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError from core.tools.tool.tool import Tool -from core.tools.utils.tool_parameter_converter import ToolParameterConverter class ToolProviderController(BaseModel, ABC): @@ -127,9 +126,7 @@ class ToolProviderController(BaseModel, ABC): # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type( - parameter_schema.default, parameter_schema.type - ) + tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default) def validate_credentials_format(self, credentials: dict[str, Any]) -> None: """ diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py index 25eaf6a66a..5656dd09ab 100644 --- a/api/core/tools/provider/workflow_tool_provider.py +++ b/api/core/tools/provider/workflow_tool_provider.py @@ -1,6 +1,6 @@ from typing import Optional -from core.app.app_config.entities import VariableEntity, VariableEntityType +from core.app.app_config.entities import VariableEntityType from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( @@ -23,6 +23,8 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = { VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING, 
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT, VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER, + VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE, + VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES, } @@ -36,8 +38,8 @@ class WorkflowToolProviderController(ToolProviderController): if not app: raise ValueError("app not found") - controller = WorkflowToolProviderController( - **{ + controller = WorkflowToolProviderController.model_validate( + { "identity": { "author": db_provider.user.name if db_provider.user_id and db_provider.user else "", "name": db_provider.label, @@ -67,7 +69,7 @@ class WorkflowToolProviderController(ToolProviderController): :param app: the app :return: the tool """ - workflow: Workflow = ( + workflow = ( db.session.query(Workflow) .filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) .first() @@ -76,14 +78,14 @@ class WorkflowToolProviderController(ToolProviderController): raise ValueError("workflow not found") # fetch start node - graph: dict = workflow.graph_dict - features_dict: dict = workflow.features_dict + graph = workflow.graph_dict + features_dict = workflow.features_dict features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW) parameters = db_provider.parameter_configurations variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) - def fetch_workflow_variable(variable_name: str) -> VariableEntity: + def fetch_workflow_variable(variable_name: str): return next(filter(lambda x: x.variable == variable_name, variables), None) user = db_provider.user @@ -114,7 +116,6 @@ class WorkflowToolProviderController(ToolProviderController): llm_description=parameter.description, required=variable.required, options=options, - default=variable.default, ) ) elif features.file_upload: @@ -123,7 +124,7 @@ class WorkflowToolProviderController(ToolProviderController): 
name=parameter.name, label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name), human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description), - type=ToolParameter.ToolParameterType.FILE, + type=ToolParameter.ToolParameterType.SYSTEM_FILES, llm_description=parameter.description, required=False, form=parameter.form, diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index cb4ab51ceb..6cb6e18b6d 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -20,10 +20,9 @@ from core.tools.entities.tool_entities import ( ToolRuntimeVariablePool, ) from core.tools.tool_file_manager import ToolFileManager -from core.tools.utils.tool_parameter_converter import ToolParameterConverter if TYPE_CHECKING: - from core.file.file_obj import FileVar + from core.file.models import File class Tool(BaseModel, ABC): @@ -63,8 +62,12 @@ class Tool(BaseModel, ABC): def __init__(self, **data: Any): super().__init__(**data) - class VariableKey(Enum): + class VariableKey(str, Enum): IMAGE = "image" + DOCUMENT = "document" + VIDEO = "video" + AUDIO = "audio" + CUSTOM = "custom" def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": """ @@ -221,9 +224,7 @@ class Tool(BaseModel, ABC): result = deepcopy(tool_parameters) for parameter in self.parameters or []: if parameter.name in tool_parameters: - result[parameter.name] = ToolParameterConverter.cast_parameter_by_type( - tool_parameters[parameter.name], parameter.type - ) + result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name]) return result @@ -295,10 +296,8 @@ class Tool(BaseModel, ABC): """ return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as) - def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.FILE_VAR, message="", meta={"file_var": file_var}, save_as="" - ) + def create_file_message(self, file: 
"File") -> ToolInvokeMessage: + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE, message="", meta={"file": file}, save_as="") def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage: """ diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index a885b8784f..2ab72213ff 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -3,7 +3,7 @@ import logging from copy import deepcopy from typing import Any, Optional, Union -from core.file.file_obj import FileTransferMethod, FileVar +from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType from core.tools.tool.tool import Tool from extensions.ext_database import db @@ -45,11 +45,13 @@ class WorkflowTool(Tool): workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version) # transform the tool parameters - tool_parameters, files = self._transform_args(tool_parameters) + tool_parameters, files = self._transform_args(tool_parameters=tool_parameters) from core.app.apps.workflow.app_generator import WorkflowAppGenerator generator = WorkflowAppGenerator() + assert self.runtime is not None + assert self.runtime.invoke_from is not None result = generator.generate( app_model=app, workflow=workflow, @@ -74,7 +76,7 @@ class WorkflowTool(Tool): else: outputs, files = self._extract_files(outputs) for file in files: - result.append(self.create_file_var_message(file)) + result.append(self.create_file_message(file)) result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) result.append(self.create_json_message(outputs)) @@ -154,22 +156,22 @@ class WorkflowTool(Tool): parameters_result = {} files = [] for parameter in parameter_rules: - if parameter.type == ToolParameter.ToolParameterType.FILE: + if parameter.type == ToolParameter.ToolParameterType.SYSTEM_FILES: file = 
tool_parameters.get(parameter.name) if file: try: - file_var_list = [FileVar(**f) for f in file] - for file_var in file_var_list: - file_dict = { - "transfer_method": file_var.transfer_method.value, - "type": file_var.type.value, + file_var_list = [File.model_validate(f) for f in file] + for file in file_var_list: + file_dict: dict[str, str | None] = { + "transfer_method": file.transfer_method.value, + "type": file.type.value, } - if file_var.transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file_var.related_id - elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file_var.related_id - elif file_var.transfer_method == FileTransferMethod.REMOTE_URL: - file_dict["url"] = file_var.preview_url + if file.transfer_method == FileTransferMethod.TOOL_FILE: + file_dict["tool_file_id"] = file.related_id + elif file.transfer_method == FileTransferMethod.LOCAL_FILE: + file_dict["upload_file_id"] = file.related_id + elif file.transfer_method == FileTransferMethod.REMOTE_URL: + file_dict["url"] = file.generate_url() files.append(file_dict) except Exception as e: @@ -179,7 +181,7 @@ class WorkflowTool(Tool): return parameters_result, files - def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]: + def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]: """ extract files from the result @@ -190,17 +192,13 @@ class WorkflowTool(Tool): result = {} for key, value in outputs.items(): if isinstance(value, list): - has_file = False for item in value: - if isinstance(item, dict) and item.get("__variant") == "FileVar": - try: - files.append(FileVar(**item)) - has_file = True - except Exception as e: - pass - if has_file: - continue + if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY: + file = File.model_validate(item) + files.append(file) + elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + file = 
File.model_validate(value) + files.append(file) result[key] = value - return result, files diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 9912114dd6..9e290c3651 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -10,7 +10,8 @@ from yarl import URL from core.app.entities.app_invoke_entities import InvokeFrom from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.file.file_obj import FileTransferMethod +from core.file import FileType +from core.file.models import FileTransferMethod from core.ops.ops_trace_manager import TraceQueueManager from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter from core.tools.errors import ( @@ -26,6 +27,7 @@ from core.tools.tool.tool import Tool from core.tools.tool.workflow_tool import WorkflowTool from core.tools.utils.message_transformer import ToolFileMessageTransformer from extensions.ext_database import db +from models.enums import CreatedByRole from models.model import Message, MessageFile @@ -128,6 +130,7 @@ class ToolEngine: """ try: # hit the callback handler + assert tool.identity is not None workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) if isinstance(tool, WorkflowTool): @@ -258,7 +261,10 @@ class ToolEngine: @staticmethod def _create_message_files( - tool_messages: list[ToolInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str + tool_messages: list[ToolInvokeMessageBinary], + agent_message: Message, + invoke_from: InvokeFrom, + user_id: str, ) -> list[tuple[Any, str]]: """ Create message file @@ -269,29 +275,31 @@ class ToolEngine: result = [] for message in tool_messages: - file_type = "bin" if "image" in message.mimetype: - file_type = "image" + file_type = FileType.IMAGE elif 
"video" in message.mimetype: - file_type = "video" + file_type = FileType.VIDEO elif "audio" in message.mimetype: - file_type = "audio" - elif "text" in message.mimetype: - file_type = "text" - elif "pdf" in message.mimetype: - file_type = "pdf" - elif "zip" in message.mimetype: - file_type = "archive" - # ... + file_type = FileType.AUDIO + elif "text" in message.mimetype or "pdf" in message.mimetype: + file_type = FileType.DOCUMENT + else: + file_type = FileType.CUSTOM + # extract tool file id from url + tool_file_id = message.url.split("/")[-1].split(".")[0] message_file = MessageFile( message_id=agent_message.id, type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE.value, + transfer_method=FileTransferMethod.TOOL_FILE, belongs_to="assistant", url=message.url, - upload_file_id=None, - created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"), + upload_file_id=tool_file_id, + created_by_role=( + CreatedByRole.ACCOUNT + if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatedByRole.END_USER + ), created_by=user_id, ) diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index ad3b9c7328..1a28df31bc 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -4,7 +4,6 @@ import hmac import logging import os import time -from collections.abc import Generator from mimetypes import guess_extension, guess_type from typing import Optional, Union from uuid import uuid4 @@ -57,22 +56,32 @@ class ToolFileManager: @staticmethod def create_file_by_raw( - user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str + *, + user_id: str, + tenant_id: str, + conversation_id: Optional[str], + file_binary: bytes, + mimetype: str, ) -> ToolFile: - """ - create file - """ extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex - filename = f"tools/{tenant_id}/{unique_name}{extension}" - 
storage.save(filename, file_binary) + filename = f"{unique_name}{extension}" + filepath = f"tools/{tenant_id}/{filename}" + storage.save(filepath, file_binary) tool_file = ToolFile( - user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_key=filepath, + mimetype=mimetype, + name=filename, + size=len(file_binary), ) db.session.add(tool_file) db.session.commit() + db.session.refresh(tool_file) return tool_file @@ -80,29 +89,34 @@ class ToolFileManager: def create_file_by_url( user_id: str, tenant_id: str, - conversation_id: str, + conversation_id: str | None, file_url: str, ) -> ToolFile: - """ - create file - """ # try to download image - response = get(file_url) - response.raise_for_status() - blob = response.content + try: + response = get(file_url) + response.raise_for_status() + blob = response.content + except Exception as e: + logger.error(f"Failed to download file from {file_url}: {e}") + raise + mimetype = guess_type(file_url)[0] or "octet/stream" extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex - filename = f"tools/{tenant_id}/{unique_name}{extension}" - storage.save(filename, blob) + filename = f"{unique_name}{extension}" + filepath = f"tools/{tenant_id}/{filename}" + storage.save(filepath, blob) tool_file = ToolFile( user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, - file_key=filename, + file_key=filepath, mimetype=mimetype, original_url=file_url, + name=filename, + size=len(blob), ) db.session.add(tool_file) @@ -110,18 +124,6 @@ class ToolFileManager: return tool_file - @staticmethod - def create_file_by_key( - user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str - ) -> ToolFile: - """ - create file - """ - tool_file = ToolFile( - user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype - ) - 
return tool_file - @staticmethod def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: """ @@ -131,7 +133,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ - tool_file: ToolFile = ( + tool_file = ( db.session.query(ToolFile) .filter( ToolFile.id == id, @@ -155,7 +157,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ - message_file: MessageFile = ( + message_file = ( db.session.query(MessageFile) .filter( MessageFile.id == id, @@ -166,13 +168,16 @@ class ToolFileManager: # Check if message_file is not None if message_file is not None: # get tool file id - tool_file_id = message_file.url.split("/")[-1] - # trim extension - tool_file_id = tool_file_id.split(".")[0] + if message_file.url is not None: + tool_file_id = message_file.url.split("/")[-1] + # trim extension + tool_file_id = tool_file_id.split(".")[0] + else: + tool_file_id = None else: tool_file_id = None - tool_file: ToolFile = ( + tool_file = ( db.session.query(ToolFile) .filter( ToolFile.id == tool_file_id, @@ -188,7 +193,7 @@ class ToolFileManager: return blob, tool_file.mimetype @staticmethod - def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generator, str], None]: + def get_file_generator_by_tool_file_id(tool_file_id: str): """ get file binary @@ -196,7 +201,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ - tool_file: ToolFile = ( + tool_file = ( db.session.query(ToolFile) .filter( ToolFile.id == tool_file_id, @@ -205,11 +210,11 @@ class ToolFileManager: ) if not tool_file: - return None + return None, None - generator = storage.load_stream(tool_file.file_key) + stream = storage.load_stream(tool_file.file_key) - return generator, tool_file.mimetype + return stream, tool_file # init tool_file_parser diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index ed66dd1357..9e984732b7 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -24,7 
+24,6 @@ from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager -from core.tools.utils.tool_parameter_converter import ToolParameterConverter from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -203,7 +202,7 @@ class ToolManager: raise ToolProviderNotFoundError(f"provider type {provider_type} not found") @classmethod - def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: + def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict): """ init runtime parameter """ @@ -222,7 +221,7 @@ class ToolManager: f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" ) - return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type) + return parameter_rule.type.cast_value(parameter_value) @classmethod def get_agent_tool_runtime( @@ -243,7 +242,11 @@ class ToolManager: parameters = tool_entity.get_all_runtime_parameters() for parameter in parameters: # check file types - if parameter.type == ToolParameter.ToolParameterType.FILE: + if parameter.type in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + }: raise ValueError(f"file type parameter {parameter.name} not supported in agent") if parameter.form == ToolParameter.ToolParameterForm.FORM: diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 3cfab207ba..1812d24571 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,7 +1,8 @@ import logging from 
mimetypes import guess_extension +from typing import Optional -from core.file.file_obj import FileTransferMethod, FileType +from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager @@ -11,7 +12,7 @@ logger = logging.getLogger(__name__) class ToolFileMessageTransformer: @classmethod def transform_tool_invoke_messages( - cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str + cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str | None ) -> list[ToolInvokeMessage]: """ Transform tool message and handle file download @@ -21,7 +22,7 @@ class ToolFileMessageTransformer: for message in messages: if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}: result.append(message) - elif message.type == ToolInvokeMessage.MessageType.IMAGE: + elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(message.message, str): # try to download image try: file = ToolFileManager.create_file_by_url( @@ -50,11 +51,14 @@ class ToolFileMessageTransformer: ) elif message.type == ToolInvokeMessage.MessageType.BLOB: # get mime type and save blob to storage + assert message.meta is not None mimetype = message.meta.get("mime_type", "octet/stream") # if message is str, encode it to bytes if isinstance(message.message, str): message.message = message.message.encode("utf-8") + # FIXME: should do a type check here. 
+ assert isinstance(message.message, bytes) file = ToolFileManager.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, @@ -63,7 +67,7 @@ class ToolFileMessageTransformer: mimetype=mimetype, ) - url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype)) + url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype)) # check if file is image if "image" in mimetype: @@ -84,12 +88,14 @@ class ToolFileMessageTransformer: meta=message.meta.copy() if message.meta is not None else {}, ) ) - elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: - file_var = message.meta.get("file_var") - if file_var: - if file_var.transfer_method == FileTransferMethod.TOOL_FILE: - url = cls.get_tool_file_url(file_var.related_id, file_var.extension) - if file_var.type == FileType.IMAGE: + elif message.type == ToolInvokeMessage.MessageType.FILE: + assert message.meta is not None + file = message.meta.get("file") + if isinstance(file, File): + if file.transfer_method == FileTransferMethod.TOOL_FILE: + assert file.related_id is not None + url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) + if file.type == FileType.IMAGE: result.append( ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, @@ -107,11 +113,13 @@ class ToolFileMessageTransformer: meta=message.meta.copy() if message.meta is not None else {}, ) ) + else: + result.append(message) else: result.append(message) return result @classmethod - def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str: + def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str: return f'/files/tools/{tool_file_id}{extension or ".bin"}' diff --git a/api/core/tools/utils/tool_parameter_converter.py b/api/core/tools/utils/tool_parameter_converter.py deleted file mode 100644 index 6f7610651c..0000000000 --- a/api/core/tools/utils/tool_parameter_converter.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import Any - -from 
core.tools.entities.tool_entities import ToolParameter - - -class ToolParameterConverter: - @staticmethod - def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str: - match parameter_type: - case ( - ToolParameter.ToolParameterType.STRING - | ToolParameter.ToolParameterType.SECRET_INPUT - | ToolParameter.ToolParameterType.SELECT - ): - return "string" - - case ToolParameter.ToolParameterType.BOOLEAN: - return "boolean" - - case ToolParameter.ToolParameterType.NUMBER: - return "number" - - case _: - raise ValueError(f"Unsupported parameter type {parameter_type}") - - @staticmethod - def cast_parameter_by_type(value: Any, parameter_type: str) -> Any: - # convert tool parameter config to correct type - try: - match parameter_type: - case ( - ToolParameter.ToolParameterType.STRING - | ToolParameter.ToolParameterType.SECRET_INPUT - | ToolParameter.ToolParameterType.SELECT - ): - if value is None: - return "" - else: - return value if isinstance(value, str) else str(value) - - case ToolParameter.ToolParameterType.BOOLEAN: - if value is None: - return False - elif isinstance(value, str): - # Allowed YAML boolean value strings: https://yaml.org/type/bool.html - # and also '0' for False and '1' for True - match value.lower(): - case "true" | "yes" | "y" | "1": - return True - case "false" | "no" | "n" | "0": - return False - case _: - return bool(value) - else: - return value if isinstance(value, bool) else bool(value) - - case ToolParameter.ToolParameterType.NUMBER: - if isinstance(value, int) | isinstance(value, float): - return value - elif isinstance(value, str) and value != "": - if "." 
in value: - return float(value) - else: - return int(value) - case ToolParameter.ToolParameterType.FILE: - return value - case _: - return str(value) - - except Exception: - raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.") diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 94d9fd9eb9..3ea07b75aa 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -1,19 +1,18 @@ +from collections.abc import Mapping, Sequence +from typing import Any + from core.app.app_config.entities import VariableEntity from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration class WorkflowToolConfigurationUtils: @classmethod - def check_parameter_configurations(cls, configurations: list[dict]): - """ - check parameter configurations - """ + def check_parameter_configurations(cls, configurations: Mapping[str, Any]): for configuration in configurations: - if not WorkflowToolParameterConfiguration(**configuration): - raise ValueError("invalid parameter configuration") + WorkflowToolParameterConfiguration.model_validate(configuration) @classmethod - def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]: + def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: """ get workflow graph variables """ diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index 99b9f80499..42c7f85bc6 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -1,4 +1,5 @@ import logging +from pathlib import Path from typing import Any import yaml @@ -17,15 +18,18 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any :param default_value: the value returned when errors ignored :return: an object of the YAML content """ - try: - with open(file_path, 
encoding="utf-8") as yaml_file: - try: - yaml_content = yaml.safe_load(yaml_file) - return yaml_content or default_value - except Exception as e: - raise YAMLError(f"Failed to load YAML file {file_path}: {e}") - except Exception as e: + if not file_path or not Path(file_path).exists(): if ignore_error: return default_value else: - raise e + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, encoding="utf-8") as yaml_file: + try: + yaml_content = yaml.safe_load(yaml_file) + return yaml_content or default_value + except Exception as e: + if ignore_error: + return default_value + else: + raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e diff --git a/api/core/app/segments/__init__.py b/api/core/variables/__init__.py similarity index 78% rename from api/core/app/segments/__init__.py rename to api/core/variables/__init__.py index 652ef243b4..87f9e3ed45 100644 --- a/api/core/app/segments/__init__.py +++ b/api/core/variables/__init__.py @@ -1,7 +1,12 @@ from .segment_group import SegmentGroup from .segments import ( ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, ArraySegment, + ArrayStringSegment, + FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -15,6 +20,7 @@ from .variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, + FileVariable, FloatVariable, IntegerVariable, NoneVariable, @@ -46,4 +52,10 @@ __all__ = [ "ArrayNumberVariable", "ArrayObjectVariable", "ArraySegment", + "ArrayFileSegment", + "ArrayNumberSegment", + "ArrayObjectSegment", + "ArrayStringSegment", + "FileSegment", + "FileVariable", ] diff --git a/api/core/app/segments/exc.py b/api/core/variables/exc.py similarity index 100% rename from api/core/app/segments/exc.py rename to api/core/variables/exc.py diff --git a/api/core/app/segments/segment_group.py b/api/core/variables/segment_group.py similarity index 100% rename from api/core/app/segments/segment_group.py rename to 
api/core/variables/segment_group.py diff --git a/api/core/app/segments/segments.py b/api/core/variables/segments.py similarity index 79% rename from api/core/app/segments/segments.py rename to api/core/variables/segments.py index b26b3c8291..782798411e 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/variables/segments.py @@ -5,6 +5,8 @@ from typing import Any from pydantic import BaseModel, ConfigDict, field_validator +from core.file import File + from .types import SegmentType @@ -39,6 +41,9 @@ class Segment(BaseModel): @property def size(self) -> int: + """ + Return the size of the value in bytes. + """ return sys.getsizeof(self.value) def to_object(self) -> Any: @@ -99,13 +104,27 @@ class ArraySegment(Segment): def markdown(self) -> str: items = [] for item in self.value: - if hasattr(item, "to_markdown"): - items.append(item.to_markdown()) - else: - items.append(str(item)) + items.append(str(item)) return "\n".join(items) +class FileSegment(Segment): + value_type: SegmentType = SegmentType.FILE + value: File + + @property + def markdown(self) -> str: + return self.value.markdown + + @property + def log(self) -> str: + return str(self.value) + + @property + def text(self) -> str: + return str(self.value) + + class ArrayAnySegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_ANY value: Sequence[Any] @@ -124,3 +143,15 @@ class ArrayNumberSegment(ArraySegment): class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT value: Sequence[Mapping[str, Any]] + + +class ArrayFileSegment(ArraySegment): + value_type: SegmentType = SegmentType.ARRAY_FILE + value: Sequence[File] + + @property + def markdown(self) -> str: + items = [] + for item in self.value: + items.append(item.markdown) + return "\n".join(items) diff --git a/api/core/app/segments/types.py b/api/core/variables/types.py similarity index 86% rename from api/core/app/segments/types.py rename to api/core/variables/types.py index 9cf0856df5..53c2e8a3aa 
100644 --- a/api/core/app/segments/types.py +++ b/api/core/variables/types.py @@ -11,5 +11,7 @@ class SegmentType(str, Enum): ARRAY_NUMBER = "array[number]" ARRAY_OBJECT = "array[object]" OBJECT = "object" + FILE = "file" + ARRAY_FILE = "array[file]" GROUP = "group" diff --git a/api/core/app/segments/variables.py b/api/core/variables/variables.py similarity index 95% rename from api/core/app/segments/variables.py rename to api/core/variables/variables.py index f0e403ab8d..ddc6914192 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/variables/variables.py @@ -7,6 +7,7 @@ from .segments import ( ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, + FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -73,3 +74,7 @@ class SecretVariable(StringVariable): class NoneVariable(NoneSegment, Variable): value_type: SegmentType = SegmentType.NONE value: None = None + + +class FileVariable(FileSegment, Variable): + pass diff --git a/api/core/workflow/callbacks/__init__.py b/api/core/workflow/callbacks/__init__.py index e69de29bb2..403fbbaa2f 100644 --- a/api/core/workflow/callbacks/__init__.py +++ b/api/core/workflow/callbacks/__init__.py @@ -0,0 +1,7 @@ +from .base_workflow_callback import WorkflowCallback +from .workflow_logging_callback import WorkflowLoggingCallback + +__all__ = [ + "WorkflowLoggingCallback", + "WorkflowCallback", +] diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py similarity index 99% rename from api/core/app/apps/workflow_logging_callback.py rename to api/core/workflow/callbacks/workflow_logging_callback.py index 60683b0f21..17913de7b0 100644 --- a/api/core/app/apps/workflow_logging_callback.py +++ b/api/core/workflow/callbacks/workflow_logging_callback.py @@ -1,7 +1,6 @@ from typing import Optional from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from 
core.workflow.graph_engine.entities.event import ( GraphEngineEvent, GraphRunFailedEvent, @@ -20,6 +19,8 @@ from core.workflow.graph_engine.entities.event import ( ParallelBranchRunSucceededEvent, ) +from .base_workflow_callback import WorkflowCallback + _TEXT_COLOR_MAPPING = { "blue": "36;1", "yellow": "33;1", diff --git a/api/core/workflow/constants.py b/api/core/workflow/constants.py new file mode 100644 index 0000000000..e3fe17c284 --- /dev/null +++ b/api/core/workflow/constants.py @@ -0,0 +1,3 @@ +SYSTEM_VARIABLE_NODE_ID = "sys" +ENVIRONMENT_VARIABLE_NODE_ID = "env" +CONVERSATION_VARIABLE_NODE_ID = "conversation" diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 5353b99ed3..0131bb342b 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,52 +1,14 @@ +from collections.abc import Mapping from enum import Enum from typing import Any, Optional from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMUsage -from models import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus -class NodeType(Enum): - """ - Node Types. - """ - - START = "start" - END = "end" - ANSWER = "answer" - LLM = "llm" - KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" - IF_ELSE = "if-else" - CODE = "code" - TEMPLATE_TRANSFORM = "template-transform" - QUESTION_CLASSIFIER = "question-classifier" - HTTP_REQUEST = "http-request" - TOOL = "tool" - VARIABLE_AGGREGATOR = "variable-aggregator" - # TODO: merge this into VARIABLE_AGGREGATOR - VARIABLE_ASSIGNER = "variable-assigner" - LOOP = "loop" - ITERATION = "iteration" - ITERATION_START = "iteration-start" # fake start node for iteration - PARAMETER_EXTRACTOR = "parameter-extractor" - CONVERSATION_VARIABLE_ASSIGNER = "assigner" - - @classmethod - def value_of(cls, value: str) -> "NodeType": - """ - Get value of given node type. 
- - :param value: node type value - :return: node type - """ - for node_type in cls: - if node_type.value == value: - return node_type - raise ValueError(f"invalid node type value {value}") - - -class NodeRunMetadataKey(Enum): +class NodeRunMetadataKey(str, Enum): """ Node Run Metadata Key. """ @@ -70,7 +32,7 @@ class NodeRunResult(BaseModel): status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING - inputs: Optional[dict[str, Any]] = None # node inputs + inputs: Optional[Mapping[str, Any]] = None # node inputs process_data: Optional[dict[str, Any]] = None # process data outputs: Optional[dict[str, Any]] = None # node outputs metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata @@ -79,24 +41,3 @@ class NodeRunResult(BaseModel): edge_source_handle: Optional[str] = None # source handle id of node with multiple branches error: Optional[str] = None # error message if status is failed - - -class UserFrom(Enum): - """ - User from - """ - - ACCOUNT = "account" - END_USER = "end-user" - - @classmethod - def value_of(cls, value: str) -> "UserFrom": - """ - Value of - :param value: value - :return: - """ - for item in cls: - if item.value == value: - return item - raise ValueError(f"Invalid value: {value}") diff --git a/api/core/workflow/entities/variable_entities.py b/api/core/workflow/entities/variable_entities.py index 1dfb1852f8..8f4c2d7975 100644 --- a/api/core/workflow/entities/variable_entities.py +++ b/api/core/workflow/entities/variable_entities.py @@ -1,3 +1,5 @@ +from collections.abc import Sequence + from pydantic import BaseModel @@ -7,4 +9,4 @@ class VariableSelector(BaseModel): """ variable: str - value_selector: list[str] + value_selector: Sequence[str] diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index b94b7f7198..5f932c0a8e 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -1,20 +1,23 @@ +import re from 
collections import defaultdict from collections.abc import Mapping, Sequence from typing import Any, Union -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field from typing_extensions import deprecated -from core.app.segments import Segment, Variable, factory -from core.file.file_obj import FileVar -from core.workflow.enums import SystemVariableKey +from core.file import File, FileAttribute, file_manager +from core.variables import Segment, SegmentGroup, Variable +from core.variables.segments import FileSegment +from factories import variable_factory -VariableValue = Union[str, int, float, dict, list, FileVar] +from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from ..enums import SystemVariableKey + +VariableValue = Union[str, int, float, dict, list, File] -SYSTEM_VARIABLE_NODE_ID = "sys" -ENVIRONMENT_VARIABLE_NODE_ID = "env" -CONVERSATION_VARIABLE_NODE_ID = "conversation" +VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") class VariablePool(BaseModel): @@ -23,46 +26,63 @@ class VariablePool(BaseModel): # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. variable_dictionary: dict[str, dict[int, Segment]] = Field( - description="Variables mapping", default=defaultdict(dict) + description="Variables mapping", + default=defaultdict(dict), ) - # TODO: This user inputs is not used for pool. 
user_inputs: Mapping[str, Any] = Field( description="User inputs", ) - system_variables: Mapping[SystemVariableKey, Any] = Field( description="System variables", ) + environment_variables: Sequence[Variable] = Field( + description="Environment variables.", + default_factory=list, + ) + conversation_variables: Sequence[Variable] = Field( + description="Conversation variables.", + default_factory=list, + ) - environment_variables: Sequence[Variable] = Field(description="Environment variables.", default_factory=list) + def __init__( + self, + *, + system_variables: Mapping[SystemVariableKey, Any] | None = None, + user_inputs: Mapping[str, Any] | None = None, + environment_variables: Sequence[Variable] | None = None, + conversation_variables: Sequence[Variable] | None = None, + **kwargs, + ): + environment_variables = environment_variables or [] + conversation_variables = conversation_variables or [] + user_inputs = user_inputs or {} + system_variables = system_variables or {} - conversation_variables: Sequence[Variable] | None = None + super().__init__( + system_variables=system_variables, + user_inputs=user_inputs, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + **kwargs, + ) - @model_validator(mode="after") - def val_model_after(self): - """ - Append system variables - :return: - """ - # Add system variables to the variable pool for key, value in self.system_variables.items(): self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) - # Add environment variables to the variable pool - for var in self.environment_variables or []: + for var in self.environment_variables: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) - # Add conversation variables to the variable pool - for var in self.conversation_variables or []: + for var in self.conversation_variables: self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) - return self - def add(self, selector: Sequence[str], value: Any, /) -> None: """ Adds a variable to 
the variable pool. + NOTE: You should not add a non-Segment value to the variable pool + even if it is allowed now. + Args: selector (Sequence[str]): The selector for the variable. value (VariableValue): The value of the variable. @@ -82,7 +102,7 @@ class VariablePool(BaseModel): if isinstance(value, Segment): v = value else: - v = factory.build_segment(value) + v = variable_factory.build_segment(value) hash_key = hash(tuple(selector[1:])) self.variable_dictionary[selector[0]][hash_key] = v @@ -101,10 +121,19 @@ class VariablePool(BaseModel): ValueError: If the selector is invalid. """ if len(selector) < 2: - raise ValueError("Invalid selector") + return None + hash_key = hash(tuple(selector[1:])) value = self.variable_dictionary[selector[0]].get(hash_key) + if value is None: + selector, attr = selector[:-1], selector[-1] + value = self.get(selector) + if isinstance(value, FileSegment): + attr = FileAttribute(attr) + attr_value = file_manager.get_attr(file=value.value, attr=attr) + return variable_factory.build_segment(attr_value) + return value @deprecated("This method is deprecated, use `get` instead.") @@ -145,14 +174,18 @@ class VariablePool(BaseModel): hash_key = hash(tuple(selector[1:])) self.variable_dictionary[selector[0]].pop(hash_key, None) - def remove_node(self, node_id: str, /): - """ - Remove all variables associated with a given node id. + def convert_template(self, template: str, /): + parts = VARIABLE_PATTERN.split(template) + segments = [] + for part in filter(lambda x: x, parts): + if "." in part and (variable := self.get(part.split("."))): + segments.append(variable) + else: + segments.append(variable_factory.build_segment(part)) + return SegmentGroup(value=segments) - Args: - node_id (str): The node id to remove. 
- - Returns: - None - """ - self.variable_dictionary.pop(node_id, None) + def get_file(self, selector: Sequence[str], /) -> FileSegment | None: + segment = self.get(selector) + if isinstance(segment, FileSegment): + return segment + return None diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 0a1eb57de4..da56af1407 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -3,12 +3,13 @@ from typing import Optional from pydantic import BaseModel from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.base_node_data_entities import BaseIterationState -from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode, UserFrom +from core.workflow.nodes.base import BaseIterationState, BaseNode +from models.enums import UserFrom from models.workflow import Workflow, WorkflowType +from .node_entities import NodeRunResult +from .variable_pool import VariablePool + class WorkflowNodeAndResult: node: BaseNode diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index 07cbcd981e..bd4ccc1072 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -1,4 +1,4 @@ -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base import BaseNode class WorkflowNodeRunFailedError(Exception): diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py index e69de29bb2..2fee3d7fad 100644 --- a/api/core/workflow/graph_engine/__init__.py +++ b/api/core/workflow/graph_engine/__init__.py @@ -0,0 +1,3 @@ +from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState + +__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] diff --git 
a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py index eda5fe079c..bc3a15bd00 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -18,11 +18,10 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler): # process condition condition_processor = ConditionProcessor() - input_conditions, group_result = condition_processor.process_conditions( - variable_pool=graph_runtime_state.variable_pool, conditions=self.condition.conditions + _, _, final_result = condition_processor.process_conditions( + variable_pool=graph_runtime_state.variable_pool, + conditions=self.condition.conditions, + operator="and", ) - # Apply the logical operator for the current case - compare_result = all(group_result) - - return compare_result + return final_result diff --git a/api/core/workflow/graph_engine/entities/__init__.py b/api/core/workflow/graph_engine/entities/__init__.py index e69de29bb2..6331a0b723 100644 --- a/api/core/workflow/graph_engine/entities/__init__.py +++ b/api/core/workflow/graph_engine/entities/__init__.py @@ -0,0 +1,6 @@ +from .graph import Graph +from .graph_init_params import GraphInitParams +from .graph_runtime_state import GraphRuntimeState +from .runtime_route_state import RuntimeRouteState + +__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 06dc4cb8f4..86d89e0a32 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -3,9 +3,9 @@ from typing import Any, Optional from pydantic import BaseModel, Field -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType from 
core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.nodes import NodeType +from core.workflow.nodes.base import BaseNodeData class GraphEngineEvent(BaseModel): diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 1175f4af2a..d87c039409 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -4,8 +4,8 @@ from typing import Any, Optional, cast from pydantic import BaseModel, Field -from core.workflow.entities.node_entities import NodeType from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.nodes import NodeType from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter diff --git a/api/core/workflow/graph_engine/entities/graph_init_params.py b/api/core/workflow/graph_engine/entities/graph_init_params.py index 1a403f3e49..a0ecd824f4 100644 --- a/api/core/workflow/graph_engine/entities/graph_init_params.py +++ b/api/core/workflow/graph_engine/entities/graph_init_params.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import BaseModel, Field from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import UserFrom +from models.enums import UserFrom from models.workflow import WorkflowType diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 8342dbd13d..ada0b14ce4 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -10,11 +10,7 @@ from flask import Flask, current_app from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities 
import InvokeFrom -from core.workflow.entities.node_entities import ( - NodeRunMetadataKey, - NodeType, - UserFrom, -) +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.event import ( @@ -36,12 +32,14 @@ from core.workflow.graph_engine.entities.graph import Graph, GraphEdge from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.nodes import NodeType from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base import BaseNode from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent -from core.workflow.nodes.node_mapping import node_classes +from core.workflow.nodes.node_mapping import node_type_classes_mapping from extensions.ext_database import db +from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType logger = logging.getLogger(__name__) @@ -229,10 +227,8 @@ class GraphEngine: raise GraphRunFailedError(f"Node {node_id} config not found.") # convert to specific node - node_type = NodeType.value_of(node_config.get("data", {}).get("type")) - node_cls = node_classes.get(node_type) - if not node_cls: - raise GraphRunFailedError(f"Node {node_id} type {node_type} not found.") + node_type = NodeType(node_config.get("data", {}).get("type")) + node_cls = node_type_classes_mapping[node_type] previous_node_id = previous_route_node_state.node_id if 
previous_route_node_state else None diff --git a/api/core/workflow/nodes/__init__.py b/api/core/workflow/nodes/__init__.py index e69de29bb2..6101fcf9af 100644 --- a/api/core/workflow/nodes/__init__.py +++ b/api/core/workflow/nodes/__init__.py @@ -0,0 +1,3 @@ +from .enums import NodeType + +__all__ = ["NodeType"] diff --git a/api/core/workflow/nodes/answer/__init__.py b/api/core/workflow/nodes/answer/__init__.py index e69de29bb2..7a10f47eed 100644 --- a/api/core/workflow/nodes/answer/__init__.py +++ b/api/core/workflow/nodes/answer/__init__.py @@ -0,0 +1,4 @@ +from .answer_node import AnswerNode +from .entities import AnswerStreamGenerateRoute + +__all__ = ["AnswerStreamGenerateRoute", "AnswerNode"] diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index deacbbbbb0..520cbdbb60 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,7 +1,8 @@ from collections.abc import Mapping, Sequence from typing import Any, cast -from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.variables import ArrayFileSegment, FileSegment +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter from core.workflow.nodes.answer.entities import ( AnswerNodeData, @@ -9,12 +10,13 @@ from core.workflow.nodes.answer.entities import ( TextGenerateRouteChunk, VarGenerateRouteChunk, ) -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser from models.workflow import WorkflowNodeExecutionStatus -class AnswerNode(BaseNode): +class AnswerNode(BaseNode[AnswerNodeData]): _node_data_cls = AnswerNodeData _node_type: NodeType = NodeType.ANSWER @@ -23,30 +25,35 @@ class AnswerNode(BaseNode): 
Run node :return: """ - node_data = self.node_data - node_data = cast(AnswerNodeData, node_data) - # generate routes - generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data) + generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data) answer = "" + files = [] for part in generate_routes: if part.type == GenerateRouteChunk.ChunkType.VAR: part = cast(VarGenerateRouteChunk, part) value_selector = part.value_selector - value = self.graph_runtime_state.variable_pool.get(value_selector) - - if value: - answer += value.markdown + variable = self.graph_runtime_state.variable_pool.get(value_selector) + if variable: + if isinstance(variable, FileSegment): + files.append(variable.value) + elif isinstance(variable, ArrayFileSegment): + files.extend(variable.value) + answer += variable.markdown else: part = cast(TextGenerateRouteChunk, part) answer += part.text - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer}) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files}) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: AnswerNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -55,9 +62,6 @@ class AnswerNode(BaseNode): :param node_data: node data :return: """ - node_data = node_data - node_data = cast(AnswerNodeData, node_data) - variable_template_parser = VariableTemplateParser(template=node_data.answer) variable_selectors = variable_template_parser.extract_variable_selectors() diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index bbd1f88867..bce28c5fcb 100644 --- 
a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -1,5 +1,4 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.answer.entities import ( AnswerNodeData, AnswerStreamGenerateRoute, @@ -7,6 +6,7 @@ from core.workflow.nodes.answer.entities import ( TextGenerateRouteChunk, VarGenerateRouteChunk, ) +from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index 32dbf436ec..e3889941ca 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -1,8 +1,8 @@ import logging from collections.abc import Generator -from typing import Optional, cast +from typing import cast -from core.file.file_obj import FileVar +from core.file import FILE_MODEL_IDENTITY, File from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, @@ -203,7 +203,7 @@ class AnswerStreamProcessor(StreamProcessor): return files @classmethod - def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]: + def _get_file_var_from_value(cls, value: dict | list): """ Get file var from value :param value: variable value @@ -213,9 +213,9 @@ class AnswerStreamProcessor(StreamProcessor): return None if isinstance(value, dict): - if "__variant" in value and value["__variant"] == FileVar.__name__: + if "dify_model_identity" in value and value["dify_model_identity"] == FILE_MODEL_IDENTITY: return value - elif isinstance(value, FileVar): + elif isinstance(value, File): return value.to_dict() return None diff --git a/api/core/workflow/nodes/answer/entities.py 
b/api/core/workflow/nodes/answer/entities.py index e356e7fd70..e543d02dd7 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -2,7 +2,7 @@ from enum import Enum from pydantic import BaseModel, Field -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class AnswerNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/base/__init__.py b/api/core/workflow/nodes/base/__init__.py new file mode 100644 index 0000000000..61f727740c --- /dev/null +++ b/api/core/workflow/nodes/base/__init__.py @@ -0,0 +1,4 @@ +from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData +from .node import BaseNode + +__all__ = ["BaseNode", "BaseNodeData", "BaseIterationNodeData", "BaseIterationState"] diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/nodes/base/entities.py similarity index 100% rename from api/core/workflow/entities/base_node_data_entities.py rename to api/core/workflow/nodes/base/entities.py diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base/node.py similarity index 60% rename from api/core/workflow/nodes/base_node.py rename to api/core/workflow/nodes/base/node.py index 7bfe45a13c..053a339ba7 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,17 +1,27 @@ -from abc import ABC, abstractmethod +import logging +from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.graph_engine.entities.event import InNodeEvent -from core.workflow.graph_engine.entities.graph import Graph -from 
core.workflow.graph_engine.entities.graph_init_params import GraphInitParams -from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState -from core.workflow.nodes.event import RunCompletedEvent, RunEvent +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import NodeEvent, RunCompletedEvent +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import BaseNodeData + +if TYPE_CHECKING: + from core.workflow.graph_engine.entities.event import InNodeEvent + from core.workflow.graph_engine.entities.graph import Graph + from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams + from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState + +logger = logging.getLogger(__name__) + +GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData) -class BaseNode(ABC): +class BaseNode(Generic[GenericNodeData]): _node_data_cls: type[BaseNodeData] _node_type: NodeType @@ -19,9 +29,9 @@ class BaseNode(ABC): self, id: str, config: Mapping[str, Any], - graph_init_params: GraphInitParams, - graph: Graph, - graph_runtime_state: GraphRuntimeState, + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", previous_node_id: Optional[str] = None, thread_pool_id: Optional[str] = None, ) -> None: @@ -45,22 +55,25 @@ class BaseNode(ABC): raise ValueError("Node ID is required.") self.node_id = node_id - self.node_data = self._node_data_cls(**config.get("data", {})) + self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {}))) @abstractmethod - def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]: + def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]: """ Run node :return: """ raise NotImplementedError - def run(self) -> Generator[RunEvent | InNodeEvent, 
None, None]: - """ - Run node entry - :return: - """ - result = self._run() + def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]: + try: + result = self._run() + except Exception as e: + logger.error(f"Node {self.node_id} failed to run: {e}") + result = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) if isinstance(result, NodeRunResult): yield RunCompletedEvent(run_result=result) @@ -69,7 +82,10 @@ class BaseNode(ABC): @classmethod def extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], config: dict + cls, + *, + graph_config: Mapping[str, Any], + config: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -83,12 +99,16 @@ class BaseNode(ABC): node_data = cls._node_data_cls(**config.get("data", {})) return cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, node_id=node_id, node_data=node_data + graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: GenericNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/code/__init__.py b/api/core/workflow/nodes/code/__init__.py index e69de29bb2..8c6dcc7fcc 100644 --- a/api/core/workflow/nodes/code/__init__.py +++ b/api/core/workflow/nodes/code/__init__.py @@ -0,0 +1,3 @@ +from .code_node import CodeNode + +__all__ = ["CodeNode"] diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 9da7ad99f3..dd533ffc4c 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,18 +1,19 @@ from collections.abc import Mapping, Sequence -from typing 
import Any, Optional, Union, cast +from typing import Any, Optional, Union from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode from core.workflow.nodes.code.entities import CodeNodeData +from core.workflow.nodes.enums import NodeType from models.workflow import WorkflowNodeExecutionStatus -class CodeNode(BaseNode): +class CodeNode(BaseNode[CodeNodeData]): _node_data_cls = CodeNodeData _node_type = NodeType.CODE @@ -33,20 +34,13 @@ class CodeNode(BaseNode): return code_provider.get_default_config() def _run(self) -> NodeRunResult: - """ - Run code - :return: - """ - node_data = self.node_data - node_data = cast(CodeNodeData, node_data) - # Get code language - code_language = node_data.code_language - code = node_data.code + code_language = self.node_data.code_language + code = self.node_data.code # Get variables variables = {} - for variable_selector in node_data.variables: + for variable_selector in self.node_data.variables: variable = variable_selector.variable value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) @@ -60,7 +54,7 @@ class CodeNode(BaseNode): ) # Transform result - result = self._transform_result(result, node_data.outputs) + result = self._transform_result(result, self.node_data.outputs) except (CodeExecutionError, ValueError) as e: return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) @@ -316,7 +310,11 @@ 
class CodeNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: CodeNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: CodeNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 5eb0e0f63f..e78183baf1 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -3,8 +3,8 @@ from typing import Literal, Optional from pydantic import BaseModel from core.helper.code_executor.code_executor import CodeLanguage -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNodeData class CodeNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/document_extractor/__init__.py b/api/core/workflow/nodes/document_extractor/__init__.py new file mode 100644 index 0000000000..3cc5fae187 --- /dev/null +++ b/api/core/workflow/nodes/document_extractor/__init__.py @@ -0,0 +1,4 @@ +from .entities import DocumentExtractorNodeData +from .node import DocumentExtractorNode + +__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"] diff --git a/api/core/workflow/nodes/document_extractor/entities.py b/api/core/workflow/nodes/document_extractor/entities.py new file mode 100644 index 0000000000..7e9ffaa889 --- /dev/null +++ b/api/core/workflow/nodes/document_extractor/entities.py @@ -0,0 +1,7 @@ +from collections.abc import Sequence + +from core.workflow.nodes.base import BaseNodeData + + +class DocumentExtractorNodeData(BaseNodeData): + variable_selector: Sequence[str] diff --git a/api/core/workflow/nodes/document_extractor/exc.py b/api/core/workflow/nodes/document_extractor/exc.py new file mode 100644 index 0000000000..c9d4bb8ef6 --- /dev/null +++ 
b/api/core/workflow/nodes/document_extractor/exc.py @@ -0,0 +1,14 @@ +class DocumentExtractorError(Exception): + """Base exception for errors related to the DocumentExtractorNode.""" + + +class FileDownloadError(DocumentExtractorError): + """Exception raised when there's an error downloading a file.""" + + +class UnsupportedFileTypeError(DocumentExtractorError): + """Exception raised when trying to extract text from an unsupported file type.""" + + +class TextExtractionError(DocumentExtractorError): + """Exception raised when there's an error during text extraction from a file.""" diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py new file mode 100644 index 0000000000..3efcc373b1 --- /dev/null +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -0,0 +1,244 @@ +import csv +import io + +import docx +import pandas as pd +import pypdfium2 +from unstructured.partition.email import partition_email +from unstructured.partition.epub import partition_epub +from unstructured.partition.msg import partition_msg +from unstructured.partition.ppt import partition_ppt +from unstructured.partition.pptx import partition_pptx + +from core.file import File, FileTransferMethod, file_manager +from core.helper import ssrf_proxy +from core.variables import ArrayFileSegment +from core.variables.segments import FileSegment +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import DocumentExtractorNodeData +from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError + + +class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): + """ + Extracts text content from various file types. + Supports plain text, PDF, and DOC/DOCX files. 
+ """ + + _node_data_cls = DocumentExtractorNodeData + _node_type = NodeType.DOCUMENT_EXTRACTOR + + def _run(self): + variable_selector = self.node_data.variable_selector + variable = self.graph_runtime_state.variable_pool.get(variable_selector) + + if variable is None: + error_message = f"File variable not found for selector: {variable_selector}" + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) + if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment): + error_message = f"Variable {variable_selector} is not an ArrayFileSegment" + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) + + value = variable.value + inputs = {"variable_selector": variable_selector} + process_data = {"documents": value if isinstance(value, list) else [value]} + + try: + if isinstance(value, list): + extracted_text_list = list(map(_extract_text_from_file, value)) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={"text": extracted_text_list}, + ) + elif isinstance(value, File): + extracted_text = _extract_text_from_file(value) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={"text": extracted_text}, + ) + else: + raise DocumentExtractorError(f"Unsupported variable type: {type(value)}") + except DocumentExtractorError as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=inputs, + process_data=process_data, + ) + + +def _extract_text(*, file_content: bytes, mime_type: str) -> str: + """Extract text from a file based on its MIME type.""" + if mime_type.startswith("text/plain") or mime_type in {"text/html", "text/htm", "text/markdown", "text/xml"}: + return _extract_text_from_plain_text(file_content) + elif mime_type == "application/pdf": + return _extract_text_from_pdf(file_content) + elif mime_type 
in { + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/msword", + }: + return _extract_text_from_doc(file_content) + elif mime_type == "text/csv": + return _extract_text_from_csv(file_content) + elif mime_type in { + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "application/vnd.ms-excel", + }: + return _extract_text_from_excel(file_content) + elif mime_type == "application/vnd.ms-powerpoint": + return _extract_text_from_ppt(file_content) + elif mime_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation": + return _extract_text_from_pptx(file_content) + elif mime_type == "application/epub+zip": + return _extract_text_from_epub(file_content) + elif mime_type == "message/rfc822": + return _extract_text_from_eml(file_content) + elif mime_type == "application/vnd.ms-outlook": + return _extract_text_from_msg(file_content) + else: + raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") + + +def _extract_text_from_plain_text(file_content: bytes) -> str: + try: + return file_content.decode("utf-8") + except UnicodeDecodeError as e: + raise TextExtractionError("Failed to decode plain text file") from e + + +def _extract_text_from_pdf(file_content: bytes) -> str: + try: + pdf_file = io.BytesIO(file_content) + pdf_document = pypdfium2.PdfDocument(pdf_file, autoclose=True) + text = "" + for page in pdf_document: + text_page = page.get_textpage() + text += text_page.get_text_range() + text_page.close() + page.close() + return text + except Exception as e: + raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e + + +def _extract_text_from_doc(file_content: bytes) -> str: + try: + doc_file = io.BytesIO(file_content) + doc = docx.Document(doc_file) + return "\n".join([paragraph.text for paragraph in doc.paragraphs]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e + + +def 
_download_file_content(file: File) -> bytes: + """Download the content of a file based on its transfer method.""" + try: + if file.transfer_method == FileTransferMethod.REMOTE_URL: + if file.remote_url is None: + raise FileDownloadError("Missing URL for remote file") + response = ssrf_proxy.get(file.remote_url) + response.raise_for_status() + return response.content + elif file.transfer_method == FileTransferMethod.LOCAL_FILE: + return file_manager.download(file) + else: + raise ValueError(f"Unsupported transfer method: {file.transfer_method}") + except Exception as e: + raise FileDownloadError(f"Error downloading file: {str(e)}") from e + + +def _extract_text_from_file(file: File): + if file.mime_type is None: + raise UnsupportedFileTypeError("Unable to determine file type: MIME type is missing") + file_content = _download_file_content(file) + extracted_text = _extract_text(file_content=file_content, mime_type=file.mime_type) + return extracted_text + + +def _extract_text_from_csv(file_content: bytes) -> str: + try: + csv_file = io.StringIO(file_content.decode("utf-8")) + csv_reader = csv.reader(csv_file) + rows = list(csv_reader) + + if not rows: + return "" + + # Create markdown table + markdown_table = "| " + " | ".join(rows[0]) + " |\n" + markdown_table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n" + for row in rows[1:]: + markdown_table += "| " + " | ".join(row) + " |\n" + + return markdown_table.strip() + except Exception as e: + raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e + + +def _extract_text_from_excel(file_content: bytes) -> str: + """Extract text from an Excel file using pandas.""" + + try: + df = pd.read_excel(io.BytesIO(file_content)) + + # Drop rows where all elements are NaN + df.dropna(how="all", inplace=True) + + # Convert DataFrame to markdown table + markdown_table = df.to_markdown(index=False) + return markdown_table + except Exception as e: + raise TextExtractionError(f"Failed to extract text from 
Excel file: {str(e)}") from e + + +def _extract_text_from_ppt(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_ppt(file=file) + return "\n".join([getattr(element, "text", "") for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from PPT: {str(e)}") from e + + +def _extract_text_from_pptx(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_pptx(file=file) + return "\n".join([getattr(element, "text", "") for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e + + +def _extract_text_from_epub(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_epub(file=file) + return "\n".join([str(element) for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e + + +def _extract_text_from_eml(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_email(file=file) + return "\n".join([str(element) for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from EML: {str(e)}") from e + + +def _extract_text_from_msg(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_msg(file=file) + return "\n".join([str(element) for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e diff --git a/api/core/workflow/nodes/end/__init__.py b/api/core/workflow/nodes/end/__init__.py index e69de29bb2..adb381701c 100644 --- a/api/core/workflow/nodes/end/__init__.py +++ b/api/core/workflow/nodes/end/__init__.py @@ -0,0 +1,4 @@ +from .end_node import EndNode +from .entities import EndStreamParam + +__all__ = ["EndStreamParam", "EndNode"] diff --git 
a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 7b78d67be8..2398e4e89d 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,13 +1,14 @@ from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.enums import NodeType from models.workflow import WorkflowNodeExecutionStatus -class EndNode(BaseNode): +class EndNode(BaseNode[EndNodeData]): _node_data_cls = EndNodeData _node_type = NodeType.END @@ -16,20 +17,27 @@ class EndNode(BaseNode): Run node :return: """ - node_data = self.node_data - node_data = cast(EndNodeData, node_data) - output_variables = node_data.outputs + output_variables = self.node_data.outputs outputs = {} for variable_selector in output_variables: - value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) + variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + value = variable.to_object() if variable is not None else None outputs[variable_selector.variable] = value - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=outputs, + outputs=outputs, + ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: EndNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: EndNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git 
a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index 9a7d2ecde3..ea8b6b5042 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -1,5 +1,5 @@ -from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam +from core.workflow.nodes.enums import NodeType class EndStreamGeneratorRouter: diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py index c3270ac22a..c16e85b0eb 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/core/workflow/nodes/end/entities.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNodeData class EndNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py new file mode 100644 index 0000000000..208144655b --- /dev/null +++ b/api/core/workflow/nodes/enums.py @@ -0,0 +1,24 @@ +from enum import Enum + + +class NodeType(str, Enum): + START = "start" + END = "end" + ANSWER = "answer" + LLM = "llm" + KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" + IF_ELSE = "if-else" + CODE = "code" + TEMPLATE_TRANSFORM = "template-transform" + QUESTION_CLASSIFIER = "question-classifier" + HTTP_REQUEST = "http-request" + TOOL = "tool" + VARIABLE_AGGREGATOR = "variable-aggregator" + VARIABLE_ASSIGNER = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. + LOOP = "loop" + ITERATION = "iteration" + ITERATION_START = "iteration-start" # Fake start node for iteration. 
+ PARAMETER_EXTRACTOR = "parameter-extractor" + CONVERSATION_VARIABLE_ASSIGNER = "assigner" + DOCUMENT_EXTRACTOR = "document-extractor" + LIST_OPERATOR = "list-operator" diff --git a/api/core/workflow/nodes/event/__init__.py b/api/core/workflow/nodes/event/__init__.py new file mode 100644 index 0000000000..581def9553 --- /dev/null +++ b/api/core/workflow/nodes/event/__init__.py @@ -0,0 +1,10 @@ +from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from .types import NodeEvent + +__all__ = [ + "RunCompletedEvent", + "RunRetrieverResourceEvent", + "RunStreamChunkEvent", + "NodeEvent", + "ModelInvokeCompletedEvent", +] diff --git a/api/core/workflow/nodes/event.py b/api/core/workflow/nodes/event/event.py similarity index 72% rename from api/core/workflow/nodes/event.py rename to api/core/workflow/nodes/event/event.py index 276c13a6d4..b7034561bf 100644 --- a/api/core/workflow/nodes/event.py +++ b/api/core/workflow/nodes/event/event.py @@ -1,5 +1,6 @@ from pydantic import BaseModel, Field +from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.node_entities import NodeRunResult @@ -17,4 +18,11 @@ class RunRetrieverResourceEvent(BaseModel): context: str = Field(..., description="context") -RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent +class ModelInvokeCompletedEvent(BaseModel): + """ + Model invoke completed + """ + + text: str + usage: LLMUsage + finish_reason: str | None = None diff --git a/api/core/workflow/nodes/event/types.py b/api/core/workflow/nodes/event/types.py new file mode 100644 index 0000000000..b19a91022d --- /dev/null +++ b/api/core/workflow/nodes/event/types.py @@ -0,0 +1,3 @@ +from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent + +NodeEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent | ModelInvokeCompletedEvent diff --git 
a/api/core/workflow/nodes/http_request/__init__.py b/api/core/workflow/nodes/http_request/__init__.py index e69de29bb2..9408c2dde0 100644 --- a/api/core/workflow/nodes/http_request/__init__.py +++ b/api/core/workflow/nodes/http_request/__init__.py @@ -0,0 +1,4 @@ +from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData +from .node import HttpRequestNode + +__all__ = ["HttpRequestNodeData", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "BodyData", "HttpRequestNode"] diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 66dd1f2dc6..816ece9577 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -1,15 +1,25 @@ -from typing import Literal, Optional, Union +from collections.abc import Sequence +from typing import Literal, Optional -from pydantic import BaseModel, ValidationInfo, field_validator +import httpx +from pydantic import BaseModel, Field, ValidationInfo, field_validator from configs import dify_config -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData + +NON_FILE_CONTENT_TYPES = ( + "application/json", + "application/xml", + "text/html", + "text/plain", + "application/x-www-form-urlencoded", +) class HttpRequestNodeAuthorizationConfig(BaseModel): - type: Literal[None, "basic", "bearer", "custom"] - api_key: Union[None, str] = None - header: Union[None, str] = None + type: Literal["basic", "bearer", "custom"] + api_key: str + header: str = "" class HttpRequestNodeAuthorization(BaseModel): @@ -31,9 +41,16 @@ class HttpRequestNodeAuthorization(BaseModel): return v +class BodyData(BaseModel): + key: str = "" + type: Literal["file", "text"] + value: str = "" + file: Sequence[str] = Field(default_factory=list) + + class HttpRequestNodeBody(BaseModel): - type: Literal["none", "form-data", 
"x-www-form-urlencoded", "raw-text", "json"] - data: Union[None, str] = None + type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json", "binary"] + data: Sequence[BodyData] = Field(default_factory=list) class HttpRequestNodeTimeout(BaseModel): @@ -54,3 +71,51 @@ class HttpRequestNodeData(BaseNodeData): params: str body: Optional[HttpRequestNodeBody] = None timeout: Optional[HttpRequestNodeTimeout] = None + + +class Response: + headers: dict[str, str] + response: httpx.Response + + def __init__(self, response: httpx.Response): + self.response = response + self.headers = dict(response.headers) + + @property + def is_file(self): + content_type = self.content_type + content_disposition = self.response.headers.get("Content-Disposition", "") + + return "attachment" in content_disposition or ( + not any(non_file in content_type for non_file in NON_FILE_CONTENT_TYPES) + and any(file_type in content_type for file_type in ("application/", "image/", "audio/", "video/")) + ) + + @property + def content_type(self) -> str: + return self.headers.get("Content-Type", "") + + @property + def text(self) -> str: + return self.response.text + + @property + def content(self) -> bytes: + return self.response.content + + @property + def status_code(self) -> int: + return self.response.status_code + + @property + def size(self) -> int: + return len(self.content) + + @property + def readable_size(self) -> str: + if self.size < 1024: + return f"{self.size} bytes" + elif self.size < 1024 * 1024: + return f"{(self.size / 1024):.2f} KB" + else: + return f"{(self.size / 1024 / 1024):.2f} MB" diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py new file mode 100644 index 0000000000..71bb0ac86a --- /dev/null +++ b/api/core/workflow/nodes/http_request/executor.py @@ -0,0 +1,327 @@ +import json +from collections.abc import Mapping, Sequence +from copy import deepcopy +from random import randint +from typing import 
Any, Literal +from urllib.parse import urlencode, urlparse + +import httpx + +from configs import dify_config +from core.file import file_manager +from core.helper import ssrf_proxy +from core.workflow.entities.variable_pool import VariablePool + +from .entities import ( + HttpRequestNodeAuthorization, + HttpRequestNodeData, + HttpRequestNodeTimeout, + Response, +) + +BODY_TYPE_TO_CONTENT_TYPE = { + "json": "application/json", + "x-www-form-urlencoded": "application/x-www-form-urlencoded", + "form-data": "multipart/form-data", + "raw-text": "text/plain", +} + + +class Executor: + method: Literal["get", "head", "post", "put", "delete", "patch"] + url: str + params: Mapping[str, str] | None + content: str | bytes | None + data: Mapping[str, Any] | None + files: Mapping[str, bytes] | None + json: Any + headers: dict[str, str] + auth: HttpRequestNodeAuthorization + timeout: HttpRequestNodeTimeout + + boundary: str + + def __init__( + self, + *, + node_data: HttpRequestNodeData, + timeout: HttpRequestNodeTimeout, + variable_pool: VariablePool, + ): + # If authorization API key is present, convert the API key using the variable pool + if node_data.authorization.type == "api-key": + if node_data.authorization.config is None: + raise ValueError("authorization config is required") + node_data.authorization.config.api_key = variable_pool.convert_template( + node_data.authorization.config.api_key + ).text + + self.url: str = node_data.url + self.method = node_data.method + self.auth = node_data.authorization + self.timeout = timeout + self.params = None + self.headers = {} + self.content = None + self.files = None + self.data = None + self.json = None + + # init template + self.variable_pool = variable_pool + self.node_data = node_data + self._initialize() + + def _initialize(self): + self._init_url() + self._init_params() + self._init_headers() + self._init_body() + + def _init_url(self): + self.url = self.variable_pool.convert_template(self.node_data.url).text + + def 
_init_params(self): + params = self.variable_pool.convert_template(self.node_data.params).text + self.params = _plain_text_to_dict(params) + + def _init_headers(self): + headers = self.variable_pool.convert_template(self.node_data.headers).text + self.headers = _plain_text_to_dict(headers) + + body = self.node_data.body + if body is None: + return + if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE: + self.headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] + if body.type == "form-data": + self.boundary = f"----WebKitFormBoundary{_generate_random_string(16)}" + self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" + + def _init_body(self): + body = self.node_data.body + if body is not None: + data = body.data + match body.type: + case "none": + self.content = "" + case "raw-text": + self.content = self.variable_pool.convert_template(data[0].value).text + case "json": + json_object = json.loads(data[0].value) + self.json = self._parse_object_contains_variables(json_object) + case "binary": + file_selector = data[0].file + file_variable = self.variable_pool.get_file(file_selector) + if file_variable is None: + raise ValueError(f"cannot fetch file with selector {file_selector}") + file = file_variable.value + self.content = file_manager.download(file) + case "x-www-form-urlencoded": + form_data = { + self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( + item.value + ).text + for item in data + } + self.data = form_data + case "form-data": + form_data = { + self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( + item.value + ).text + for item in filter(lambda item: item.type == "text", data) + } + file_selectors = { + self.variable_pool.convert_template(item.key).text: item.file + for item in filter(lambda item: item.type == "file", data) + } + files = {k: self.variable_pool.get_file(selector) for k, selector 
in file_selectors.items()} + files = {k: v for k, v in files.items() if v is not None} + files = {k: variable.value for k, variable in files.items()} + files = {k: file_manager.download(v) for k, v in files.items() if v.related_id is not None} + + self.data = form_data + self.files = files + + def _assembling_headers(self) -> dict[str, Any]: + authorization = deepcopy(self.auth) + headers = deepcopy(self.headers) or {} + if self.auth.type == "api-key": + if self.auth.config is None: + raise ValueError("self.authorization config is required") + if authorization.config is None: + raise ValueError("authorization config is required") + + if self.auth.config.api_key is None: + raise ValueError("api_key is required") + + if not authorization.config.header: + authorization.config.header = "Authorization" + + if self.auth.config.type == "bearer": + headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" + elif self.auth.config.type == "basic": + headers[authorization.config.header] = f"Basic {authorization.config.api_key}" + elif self.auth.config.type == "custom": + headers[authorization.config.header] = authorization.config.api_key or "" + + return headers + + def _validate_and_parse_response(self, response: httpx.Response) -> Response: + executor_response = Response(response) + + threshold_size = ( + dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE + if executor_response.is_file + else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE + ) + if executor_response.size > threshold_size: + raise ValueError( + f'{"File" if executor_response.is_file else "Text"} size is too large,' + f' max size is {threshold_size / 1024 / 1024:.2f} MB,' + f' but current size is {executor_response.readable_size}.' 
+ ) + + return executor_response + + def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: + """ + do http request depending on api bundle + """ + if self.method not in {"get", "head", "post", "put", "delete", "patch"}: + raise ValueError(f"Invalid http method {self.method}") + + request_args = { + "url": self.url, + "data": self.data, + "files": self.files, + "json": self.json, + "content": self.content, + "headers": headers, + "params": self.params, + "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), + "follow_redirects": True, + } + + response = getattr(ssrf_proxy, self.method)(**request_args) + return response + + def invoke(self) -> Response: + # assemble headers + headers = self._assembling_headers() + # do http request + response = self._do_http_request(headers) + # validate response + return self._validate_and_parse_response(response) + + def to_log(self): + url_parts = urlparse(self.url) + path = url_parts.path or "/" + + # Add query parameters + if self.params: + query_string = urlencode(self.params) + path += f"?{query_string}" + elif url_parts.query: + path += f"?{url_parts.query}" + + raw = f"{self.method.upper()} {path} HTTP/1.1\r\n" + raw += f"Host: {url_parts.netloc}\r\n" + + headers = self._assembling_headers() + for k, v in headers.items(): + if self.auth.type == "api-key": + authorization_header = "Authorization" + if self.auth.config and self.auth.config.header: + authorization_header = self.auth.config.header + if k.lower() == authorization_header.lower(): + raw += f'{k}: {"*" * len(v)}\r\n' + continue + raw += f"{k}: {v}\r\n" + + body = "" + if self.files: + boundary = self.boundary + for k, v in self.files.items(): + body += f"--{boundary}\r\n" + body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' + body += f"{v[1]}\r\n" + body += f"--{boundary}--\r\n" + elif self.node_data.body: + if self.content: + if isinstance(self.content, str): + body = self.content + elif isinstance(self.content, 
bytes): + body = self.content.decode("utf-8", errors="replace") + elif self.data and self.node_data.body.type == "x-www-form-urlencoded": + body = urlencode(self.data) + elif self.data and self.node_data.body.type == "form-data": + boundary = self.boundary + for key, value in self.data.items(): + body += f"--{boundary}\r\n" + body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' + body += f"{value}\r\n" + body += f"--{boundary}--\r\n" + elif self.json: + body = json.dumps(self.json) + elif self.node_data.body.type == "raw-text": + body = self.node_data.body.data[0].value + if body: + raw += f"Content-Length: {len(body)}\r\n" + raw += "\r\n" # Empty line between headers and body + raw += body + + return raw + + def _parse_object_contains_variables(self, obj: str | dict | list, /) -> Mapping[str, Any] | Sequence[Any] | str: + if isinstance(obj, dict): + return {k: self._parse_object_contains_variables(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._parse_object_contains_variables(v) for v in obj] + elif isinstance(obj, str): + return self.variable_pool.convert_template(obj).text + + +def _plain_text_to_dict(text: str, /) -> dict[str, str]: + """ + Convert a string of key-value pairs to a dictionary. + + Each line in the input string represents a key-value pair. + Keys and values are separated by ':'. + Empty values are allowed. + + Examples: + 'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'} + 'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'} + 'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'} + + Args: + convert_text (str): The input string to convert. + + Returns: + dict[str, str]: A dictionary of key-value pairs. + """ + return { + key.strip(): (value[0].strip() if value else "") + for line in text.splitlines() + if line.strip() + for key, *value in [line.split(":", 1)] + } + + +def _generate_random_string(n: int) -> str: + """ + Generate a random string of lowercase ASCII letters. + + Args: + n (int): The length of the random string to generate. 
+ + Returns: + str: A random string of lowercase ASCII letters with length n. + + Example: + >>> _generate_random_string(5) + 'abcde' + """ + return "".join([chr(randint(97, 122)) for _ in range(n)]) diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py deleted file mode 100644 index f8ab4e3132..0000000000 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ /dev/null @@ -1,343 +0,0 @@ -import json -from copy import deepcopy -from random import randint -from typing import Any, Optional, Union -from urllib.parse import urlencode - -import httpx - -from configs import dify_config -from core.helper import ssrf_proxy -from core.workflow.entities.variable_entities import VariableSelector -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.http_request.entities import ( - HttpRequestNodeAuthorization, - HttpRequestNodeBody, - HttpRequestNodeData, - HttpRequestNodeTimeout, -) -from core.workflow.utils.variable_template_parser import VariableTemplateParser - - -class HttpExecutorResponse: - headers: dict[str, str] - response: httpx.Response - - def __init__(self, response: httpx.Response): - self.response = response - self.headers = dict(response.headers) if isinstance(self.response, httpx.Response) else {} - - @property - def is_file(self) -> bool: - """ - check if response is file - """ - content_type = self.get_content_type() - file_content_types = ["image", "audio", "video"] - - return any(v in content_type for v in file_content_types) - - def get_content_type(self) -> str: - return self.headers.get("content-type", "") - - def extract_file(self) -> tuple[str, bytes]: - """ - extract file from response if content type is file related - """ - if self.is_file: - return self.get_content_type(), self.body - - return "", b"" - - @property - def content(self) -> str: - if isinstance(self.response, httpx.Response): - return self.response.text - else: - raise 
ValueError(f"Invalid response type {type(self.response)}") - - @property - def body(self) -> bytes: - if isinstance(self.response, httpx.Response): - return self.response.content - else: - raise ValueError(f"Invalid response type {type(self.response)}") - - @property - def status_code(self) -> int: - if isinstance(self.response, httpx.Response): - return self.response.status_code - else: - raise ValueError(f"Invalid response type {type(self.response)}") - - @property - def size(self) -> int: - return len(self.body) - - @property - def readable_size(self) -> str: - if self.size < 1024: - return f"{self.size} bytes" - elif self.size < 1024 * 1024: - return f"{(self.size / 1024):.2f} KB" - else: - return f"{(self.size / 1024 / 1024):.2f} MB" - - -class HttpExecutor: - server_url: str - method: str - authorization: HttpRequestNodeAuthorization - params: dict[str, Any] - headers: dict[str, Any] - body: Union[None, str] - files: Union[None, dict[str, Any]] - boundary: str - variable_selectors: list[VariableSelector] - timeout: HttpRequestNodeTimeout - - def __init__( - self, - node_data: HttpRequestNodeData, - timeout: HttpRequestNodeTimeout, - variable_pool: Optional[VariablePool] = None, - ): - self.server_url = node_data.url - self.method = node_data.method - self.authorization = node_data.authorization - self.timeout = timeout - self.params = {} - self.headers = {} - self.body = None - self.files = None - - # init template - self.variable_selectors = [] - self._init_template(node_data, variable_pool) - - @staticmethod - def _is_json_body(body: HttpRequestNodeBody): - """ - check if body is json - """ - if body and body.type == "json" and body.data: - try: - json.loads(body.data) - return True - except: - return False - - return False - - @staticmethod - def _to_dict(convert_text: str): - """ - Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}` - """ - kv_paris = convert_text.split("\n") - result = {} - for kv in kv_paris: - if not kv.strip(): - continue 
- - kv = kv.split(":", maxsplit=1) - if len(kv) == 1: - k, v = kv[0], "" - else: - k, v = kv - result[k.strip()] = v - return result - - def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None): - # extract all template in url - self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool) - - # extract all template in params - params, params_variable_selectors = self._format_template(node_data.params, variable_pool) - self.params = self._to_dict(params) - - # extract all template in headers - headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool) - self.headers = self._to_dict(headers) - - # extract all template in body - body_data_variable_selectors = [] - if node_data.body: - # check if it's a valid JSON - is_valid_json = self._is_json_body(node_data.body) - - body_data = node_data.body.data or "" - if body_data: - body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json) - - content_type_is_set = any(key.lower() == "content-type" for key in self.headers) - if node_data.body.type == "json" and not content_type_is_set: - self.headers["Content-Type"] = "application/json" - elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set: - self.headers["Content-Type"] = "application/x-www-form-urlencoded" - - if node_data.body.type in {"form-data", "x-www-form-urlencoded"}: - body = self._to_dict(body_data) - - if node_data.body.type == "form-data": - self.files = {k: ("", v) for k, v in body.items()} - random_str = lambda n: "".join([chr(randint(97, 122)) for _ in range(n)]) - self.boundary = f"----WebKitFormBoundary{random_str(16)}" - - self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" - else: - self.body = urlencode(body) - elif node_data.body.type in {"json", "raw-text"}: - self.body = body_data - elif node_data.body.type == "none": - self.body = 
"" - - self.variable_selectors = ( - server_url_variable_selectors - + params_variable_selectors - + headers_variable_selectors - + body_data_variable_selectors - ) - - def _assembling_headers(self) -> dict[str, Any]: - authorization = deepcopy(self.authorization) - headers = deepcopy(self.headers) or {} - if self.authorization.type == "api-key": - if self.authorization.config is None: - raise ValueError("self.authorization config is required") - if authorization.config is None: - raise ValueError("authorization config is required") - - if self.authorization.config.api_key is None: - raise ValueError("api_key is required") - - if not authorization.config.header: - authorization.config.header = "Authorization" - - if self.authorization.config.type == "bearer": - headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" - elif self.authorization.config.type == "basic": - headers[authorization.config.header] = f"Basic {authorization.config.api_key}" - elif self.authorization.config.type == "custom": - headers[authorization.config.header] = authorization.config.api_key - - return headers - - def _validate_and_parse_response(self, response: httpx.Response) -> HttpExecutorResponse: - """ - validate the response - """ - if isinstance(response, httpx.Response): - executor_response = HttpExecutorResponse(response) - else: - raise ValueError(f"Invalid response type {type(response)}") - - threshold_size = ( - dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE - if executor_response.is_file - else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE - ) - if executor_response.size > threshold_size: - raise ValueError( - f'{"File" if executor_response.is_file else "Text"} size is too large,' - f' max size is {threshold_size / 1024 / 1024:.2f} MB,' - f' but current size is {executor_response.readable_size}.' 
- ) - - return executor_response - - def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: - """ - do http request depending on api bundle - """ - kwargs = { - "url": self.server_url, - "headers": headers, - "params": self.params, - "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), - "follow_redirects": True, - } - - if self.method in {"get", "head", "post", "put", "delete", "patch"}: - response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) - else: - raise ValueError(f"Invalid http method {self.method}") - return response - - def invoke(self) -> HttpExecutorResponse: - """ - invoke http request - """ - # assemble headers - headers = self._assembling_headers() - - # do http request - response = self._do_http_request(headers) - - # validate response - return self._validate_and_parse_response(response) - - def to_raw_request(self) -> str: - """ - convert to raw request - """ - server_url = self.server_url - if self.params: - server_url += f"?{urlencode(self.params)}" - - raw_request = f"{self.method.upper()} {server_url} HTTP/1.1\n" - - headers = self._assembling_headers() - for k, v in headers.items(): - # get authorization header - if self.authorization.type == "api-key": - authorization_header = "Authorization" - if self.authorization.config and self.authorization.config.header: - authorization_header = self.authorization.config.header - - if k.lower() == authorization_header.lower(): - raw_request += f'{k}: {"*" * len(v)}\n' - continue - - raw_request += f"{k}: {v}\n" - - raw_request += "\n" - - # if files, use multipart/form-data with boundary - if self.files: - boundary = self.boundary - raw_request += f"--{boundary}" - for k, v in self.files.items(): - raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n' - raw_request += f"{v[1]}\n" - raw_request += f"--{boundary}" - raw_request += "--" - else: - raw_request += self.body or "" - - return raw_request - - def _format_template( 
- self, template: str, variable_pool: Optional[VariablePool], escape_quotes: bool = False - ) -> tuple[str, list[VariableSelector]]: - """ - format template - """ - variable_template_parser = VariableTemplateParser(template=template) - variable_selectors = variable_template_parser.extract_variable_selectors() - - if variable_pool: - variable_value_mapping = {} - for variable_selector in variable_selectors: - variable = variable_pool.get_any(variable_selector.value_selector) - if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") - if escape_quotes and isinstance(variable, str): - value = variable.replace('"', '\\"').replace("\n", "\\n") - else: - value = variable - variable_value_mapping[variable_selector.variable] = value - - return variable_template_parser.format(variable_value_mapping), variable_selectors - else: - return template, variable_selectors diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py deleted file mode 100644 index cd40819126..0000000000 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ /dev/null @@ -1,165 +0,0 @@ -import logging -from collections.abc import Mapping, Sequence -from mimetypes import guess_extension -from os import path -from typing import Any, cast - -from configs import dify_config -from core.app.segments import parser -from core.file.file_obj import FileTransferMethod, FileType, FileVar -from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.http_request.entities import ( - HttpRequestNodeData, - HttpRequestNodeTimeout, -) -from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse -from models.workflow import WorkflowNodeExecutionStatus - -HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( - 
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, - read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, - write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, -) - - -class HttpRequestNode(BaseNode): - _node_data_cls = HttpRequestNodeData - _node_type = NodeType.HTTP_REQUEST - - @classmethod - def get_default_config(cls, filters: dict | None = None) -> dict: - return { - "type": "http-request", - "config": { - "method": "get", - "authorization": { - "type": "no-auth", - }, - "body": {"type": "none"}, - "timeout": { - **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), - "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, - "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, - "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, - }, - }, - } - - def _run(self) -> NodeRunResult: - node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data) - # TODO: Switch to use segment directly - if node_data.authorization.config and node_data.authorization.config.api_key: - node_data.authorization.config.api_key = parser.convert_template( - template=node_data.authorization.config.api_key, variable_pool=self.graph_runtime_state.variable_pool - ).text - - # init http executor - http_executor = None - try: - http_executor = HttpExecutor( - node_data=node_data, - timeout=self._get_request_timeout(node_data), - variable_pool=self.graph_runtime_state.variable_pool, - ) - - # invoke http executor - response = http_executor.invoke() - except Exception as e: - process_data = {} - if http_executor: - process_data = { - "request": http_executor.to_raw_request(), - } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - process_data=process_data, - ) - - files = self.extract_files(http_executor.server_url, response) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "status_code": response.status_code, - "body": response.content if not files else "", - "headers": 
response.headers, - "files": files, - }, - process_data={ - "request": http_executor.to_raw_request(), - }, - ) - - @staticmethod - def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: - timeout = node_data.timeout - if timeout is None: - return HTTP_REQUEST_DEFAULT_TIMEOUT - - timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect - timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read - timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write - return timeout - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: HttpRequestNodeData - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - try: - http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT) - - variable_selectors = http_executor.variable_selectors - - variable_mapping = {} - for variable_selector in variable_selectors: - variable_mapping[node_id + "." 
+ variable_selector.variable] = variable_selector.value_selector - - return variable_mapping - except Exception as e: - logging.exception(f"Failed to extract variable selector to variable mapping: {e}") - return {} - - def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]: - """ - Extract files from response - """ - files = [] - mimetype, file_binary = response.extract_file() - - if mimetype: - # extract filename from url - filename = path.basename(url) - # extract extension if possible - extension = guess_extension(mimetype) or ".bin" - - tool_file = ToolFileManager.create_file_by_raw( - user_id=self.user_id, - tenant_id=self.tenant_id, - conversation_id=None, - file_binary=file_binary, - mimetype=mimetype, - ) - - files.append( - FileVar( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file.id, - filename=filename, - extension=extension, - mime_type=mimetype, - ) - ) - - return files diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py new file mode 100644 index 0000000000..483d0e2b7e --- /dev/null +++ b/api/core/workflow/nodes/http_request/node.py @@ -0,0 +1,174 @@ +import logging +from collections.abc import Mapping, Sequence +from mimetypes import guess_extension +from os import path +from typing import Any + +from configs import dify_config +from core.file import File, FileTransferMethod, FileType +from core.tools.tool_file_manager import ToolFileManager +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.http_request.executor import Executor +from core.workflow.utils import variable_template_parser +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import ( + HttpRequestNodeData, + 
HttpRequestNodeTimeout, + Response, +) + +HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( + connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, +) + +logger = logging.getLogger(__name__) + + +class HttpRequestNode(BaseNode[HttpRequestNodeData]): + _node_data_cls = HttpRequestNodeData + _node_type = NodeType.HTTP_REQUEST + + @classmethod + def get_default_config(cls, filters: dict | None = None) -> dict: + return { + "type": "http-request", + "config": { + "method": "get", + "authorization": { + "type": "no-auth", + }, + "body": {"type": "none"}, + "timeout": { + **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), + "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, + }, + }, + } + + def _run(self) -> NodeRunResult: + process_data = {} + try: + http_executor = Executor( + node_data=self.node_data, + timeout=self._get_request_timeout(self.node_data), + variable_pool=self.graph_runtime_state.variable_pool, + ) + process_data["request"] = http_executor.to_log() + + response = http_executor.invoke() + files = self.extract_files(url=http_executor.url, response=response) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + "status_code": response.status_code, + "body": response.text if not files else "", + "headers": response.headers, + "files": files, + }, + process_data={ + "request": http_executor.to_log(), + }, + ) + except Exception as e: + logger.warning(f"http request node {self.node_id} failed to run: {e}") + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + process_data=process_data, + ) + + @staticmethod + def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: + timeout = node_data.timeout + if timeout is None: + 
return HTTP_REQUEST_DEFAULT_TIMEOUT + + timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect + timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read + timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write + return timeout + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: HttpRequestNodeData, + ) -> Mapping[str, Sequence[str]]: + selectors: list[VariableSelector] = [] + selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) + selectors += variable_template_parser.extract_selectors_from_template(node_data.params) + if node_data.body: + body_type = node_data.body.type + data = node_data.body.data + match body_type: + case "binary": + selector = data[0].file + selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector)) + case "json" | "raw-text": + selectors += variable_template_parser.extract_selectors_from_template(data[0].key) + selectors += variable_template_parser.extract_selectors_from_template(data[0].value) + case "x-www-form-urlencoded": + for item in data: + selectors += variable_template_parser.extract_selectors_from_template(item.key) + selectors += variable_template_parser.extract_selectors_from_template(item.value) + case "form-data": + for item in data: + selectors += variable_template_parser.extract_selectors_from_template(item.key) + if item.type == "text": + selectors += variable_template_parser.extract_selectors_from_template(item.value) + elif item.type == "file": + selectors.append( + VariableSelector(variable="#" + ".".join(item.file) + "#", value_selector=item.file) + ) + + mapping = {} + for selector in selectors: + mapping[node_id + "." 
+ selector.variable] = selector.value_selector + + return mapping + + def extract_files(self, url: str, response: Response) -> list[File]: + """ + Extract files from response + """ + files = [] + content_type = response.content_type + content = response.content + + if content_type: + # extract filename from url + filename = path.basename(url) + # extract extension if possible + extension = guess_extension(content_type) or ".bin" + + tool_file = ToolFileManager.create_file_by_raw( + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=None, + file_binary=content, + mimetype=content_type, + ) + + files.append( + File( + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=tool_file.id, + filename=filename, + extension=extension, + mime_type=content_type, + ) + ) + + return files diff --git a/api/core/workflow/nodes/if_else/__init__.py b/api/core/workflow/nodes/if_else/__init__.py index e69de29bb2..afa0e8112c 100644 --- a/api/core/workflow/nodes/if_else/__init__.py +++ b/api/core/workflow/nodes/if_else/__init__.py @@ -0,0 +1,3 @@ +from .if_else_node import IfElseNode + +__all__ = ["IfElseNode"] diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py index 54c1081fd3..23f5d2cc31 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/core/workflow/nodes/if_else/entities.py @@ -1,8 +1,8 @@ from typing import Literal, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData from core.workflow.utils.condition.entities import Condition @@ -21,6 +21,6 @@ class IfElseNodeData(BaseNodeData): conditions: list[Condition] logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = None + conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) 
cases: Optional[list[Case]] = None diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 37384202d8..6960fc045a 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,14 +1,19 @@ from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any, Literal -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from typing_extensions import deprecated + +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.if_else.entities import IfElseNodeData +from core.workflow.utils.condition.entities import Condition from core.workflow.utils.condition.processor import ConditionProcessor from models.workflow import WorkflowNodeExecutionStatus -class IfElseNode(BaseNode): +class IfElseNode(BaseNode[IfElseNodeData]): _node_data_cls = IfElseNodeData _node_type = NodeType.IF_ELSE @@ -17,9 +22,6 @@ class IfElseNode(BaseNode): Run node :return: """ - node_data = self.node_data - node_data = cast(IfElseNodeData, node_data) - node_inputs: dict[str, list] = {"conditions": []} process_datas: dict[str, list] = {"condition_results": []} @@ -30,15 +32,14 @@ class IfElseNode(BaseNode): condition_processor = ConditionProcessor() try: # Check if the new cases structure is used - if node_data.cases: - for case in node_data.cases: - input_conditions, group_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions + if self.node_data.cases: + for case in self.node_data.cases: + input_conditions, group_result, final_result = condition_processor.process_conditions( + 
@deprecated("This function is deprecated. You should use the new cases structure.")
def _should_not_use_old_function(
    *,
    condition_processor: ConditionProcessor,
    variable_pool: VariablePool,
    conditions: list[Condition],
    operator: Literal["and", "or"],
):
    """Evaluate the legacy flat ``conditions`` list of an if/else node.

    Thin forwarding wrapper around ``ConditionProcessor.process_conditions``;
    it exists so that the legacy (non-``cases``) code path goes through a
    function explicitly marked as deprecated.

    :param condition_processor: processor that performs the evaluation
    :param variable_pool: pool the conditions are resolved against
    :param conditions: legacy condition list
    :param operator: how the individual results are combined
    :return: whatever ``process_conditions`` returns, unchanged
    """
    return condition_processor.process_conditions(
        variable_pool=variable_pool,
        conditions=conditions,
        operator=operator,
    )
""" - outputs: list[Any] = None + outputs: list[Any] = Field(default_factory=list) current_output: Optional[Any] = None class MetaData(BaseIterationState.MetaData): diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 01bb4e9076..b28ae0a85c 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -5,7 +5,7 @@ from typing import Any, cast from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.graph_engine.entities.event import ( BaseGraphEvent, BaseNodeEvent, @@ -20,15 +20,16 @@ from core.workflow.graph_engine.entities.event import ( NodeRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.event import RunCompletedEvent, RunEvent +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import IterationNodeData from models.workflow import WorkflowNodeExecutionStatus logger = logging.getLogger(__name__) -class IterationNode(BaseNode): +class IterationNode(BaseNode[IterationNodeData]): """ Iteration Node. """ @@ -36,11 +37,10 @@ class IterationNode(BaseNode): _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION - def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """ Run the node. 
""" - self.node_data = cast(IterationNodeData, self.node_data) iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) if not iterator_list_segment: @@ -177,7 +177,7 @@ class IterationNode(BaseNode): # remove all nodes outputs from variable pool for node_id in iteration_graph.node_ids: - variable_pool.remove_node(node_id) + variable_pool.remove([node_id]) # move to next iteration current_index = variable_pool.get([self.node_id, "index"]) @@ -247,7 +247,11 @@ class IterationNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IterationNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -273,15 +277,13 @@ class IterationNode(BaseNode): # variable selector to variable mapping try: # Get node class - from core.workflow.nodes.node_mapping import node_classes + from core.workflow.nodes.node_mapping import node_type_classes_mapping - node_type = NodeType.value_of(sub_node_config.get("data", {}).get("type")) - node_cls = node_classes.get(node_type) + node_type = NodeType(sub_node_config.get("data", {}).get("type")) + node_cls = node_type_classes_mapping.get(node_type) if not node_cls: continue - node_cls = cast(BaseNode, node_cls) - sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( graph_config=graph_config, config=sub_node_config ) diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 88b9665ac6..6ab7c30106 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -1,8 +1,9 @@ from collections.abc import Mapping, Sequence from typing import Any -from core.workflow.entities.node_entities import 
NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData from models.workflow import WorkflowNodeExecutionStatus diff --git a/api/core/workflow/nodes/knowledge_retrieval/__init__.py b/api/core/workflow/nodes/knowledge_retrieval/__init__.py index e69de29bb2..4d4a4cbd9f 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/__init__.py +++ b/api/core/workflow/nodes/knowledge_retrieval/__init__.py @@ -0,0 +1,3 @@ +from .knowledge_retrieval_node import KnowledgeRetrievalNode + +__all__ = ["KnowledgeRetrievalNode"] diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 1cd88039b1..e8972d1381 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -2,7 +2,7 @@ from typing import Any, Literal, Optional from pydantic import BaseModel -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 8cd208d7fc..b286f34d7f 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -14,8 +14,9 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import 
RetrievalMethod -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -32,15 +33,13 @@ default_retrieval_model = { } -class KnowledgeRetrievalNode(BaseNode): +class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): _node_data_cls = KnowledgeRetrievalNodeData - node_type = NodeType.KNOWLEDGE_RETRIEVAL + _node_type = NodeType.KNOWLEDGE_RETRIEVAL def _run(self) -> NodeRunResult: - node_data = cast(KnowledgeRetrievalNodeData, self.node_data) - # extract variables - variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector) + variable = self.graph_runtime_state.variable_pool.get_any(self.node_data.query_variable_selector) query = variable variables = {"query": query} if not query: @@ -49,7 +48,7 @@ class KnowledgeRetrievalNode(BaseNode): ) # retrieve knowledge try: - results = self._fetch_dataset_retriever(node_data=node_data, query=query) + results = self._fetch_dataset_retriever(node_data=self.node_data, query=query) outputs = {"result": results} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs @@ -244,7 +243,11 @@ class KnowledgeRetrievalNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: KnowledgeRetrievalNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: KnowledgeRetrievalNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git 
# Comparison operators a filter condition may use. The string operators and
# the number operators share one Literal.
_Condition = Literal[
    # string conditions
    "contains",
    "startswith",
    "endswith",
    "is",
    "in",
    "empty",
    "not contains",
    "not is",
    "not in",
    "not empty",
    # number conditions
    "=",
    "!=",
    "<",
    ">",
    "≥",
    "≤",
]


class FilterCondition(BaseModel):
    """One filter rule: ``key`` compared to ``value`` with ``comparison_operator``."""

    # Attribute the rule applies to; empty by default.
    key: str = ""
    comparison_operator: _Condition = "contains"
    # Either a single string or a sequence of strings.
    value: str | Sequence[str] = ""


class FilterBy(BaseModel):
    """Filter step configuration; disabled and without conditions by default."""

    enabled: bool = False
    conditions: Sequence[FilterCondition] = Field(default_factory=list)


class OrderBy(BaseModel):
    """Sort step configuration; disabled by default."""

    enabled: bool = False
    # Attribute to sort by; empty by default.
    key: str = ""
    # Sort direction.
    value: Literal["asc", "desc"] = "asc"


class Limit(BaseModel):
    """Truncation step configuration; disabled by default."""

    enabled: bool = False
    # Defaults to -1.
    size: int = -1


class ListOperatorNodeData(BaseNodeData):
    """Node data of the list-operator node."""

    # Selector of the variable to operate on.
    variable: Sequence[str] = Field(default_factory=list)
    # The three step configurations below have no default and are required.
    filter_by: FilterBy
    order_by: OrderBy
    limit: Limit
class ListOperatorNode(BaseNode[ListOperatorNodeData]):
    """Workflow node that filters, orders and truncates an array variable."""

    _node_data_cls = ListOperatorNodeData
    _node_type = NodeType.LIST_OPERATOR

    def _run(self):
        """Apply the configured filter, order and limit steps, in that order.

        Returns a FAILED result when the variable is missing, or when it has
        a truthy value but is not a file/number/string array segment.
        Otherwise returns SUCCEEDED with ``result``, ``first_record`` and
        ``last_record`` outputs.
        """
        inputs = {}
        process_data = {}
        outputs = {}

        node_data = self.node_data
        pool = self.graph_runtime_state.variable_pool

        def failed(message: str) -> NodeRunResult:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, error=message, inputs=inputs, outputs=outputs
            )

        def with_value(segment, new_value):
            # Segments are replaced by updated copies, never mutated.
            return segment.model_copy(update={"value": new_value})

        variable = pool.get(node_data.variable)
        if variable is None:
            return failed(f"Variable not found for selector: {node_data.variable}")
        if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
            return failed(
                f"Variable {node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
                "or ArrayStringSegment"
            )

        # Record the unprocessed input; files are logged in dict form.
        process_data["variable"] = (
            [item.to_dict() for item in variable.value]
            if isinstance(variable, ArrayFileSegment)
            else variable.value
        )

        # Filter
        if node_data.filter_by.enabled:
            for condition in node_data.filter_by.conditions:
                raw = condition.value
                if isinstance(variable, ArrayStringSegment):
                    if not isinstance(raw, str):
                        raise ValueError(f"Invalid filter value: {condition.value}")
                    rendered = pool.convert_template(raw).text
                    predicate = _get_string_filter_func(condition=condition.comparison_operator, value=rendered)
                elif isinstance(variable, ArrayNumberSegment):
                    if not isinstance(raw, str):
                        raise ValueError(f"Invalid filter value: {condition.value}")
                    rendered = pool.convert_template(raw).text
                    predicate = _get_number_filter_func(
                        condition=condition.comparison_operator, value=float(rendered)
                    )
                elif isinstance(variable, ArrayFileSegment):
                    # Only string values go through template rendering.
                    operand = pool.convert_template(raw).text if isinstance(raw, str) else raw
                    predicate = _get_file_filter_func(
                        key=condition.key,
                        condition=condition.comparison_operator,
                        value=operand,
                    )
                else:
                    continue
                variable = with_value(variable, [item for item in variable.value if predicate(item)])

        # Order
        ordering = node_data.order_by
        if ordering.enabled:
            if isinstance(variable, ArrayStringSegment):
                variable = with_value(variable, _order_string(order=ordering.value, array=variable.value))
            elif isinstance(variable, ArrayNumberSegment):
                variable = with_value(variable, _order_number(order=ordering.value, array=variable.value))
            elif isinstance(variable, ArrayFileSegment):
                variable = with_value(
                    variable,
                    _order_file(order=ordering.value, order_by=ordering.key, array=variable.value),
                )

        # Slice
        if node_data.limit.enabled:
            variable = with_value(variable, variable.value[: node_data.limit.size])

        final = variable.value
        outputs = {
            "result": final,
            "first_record": final[0] if final else None,
            "last_record": final[-1] if final else None,
        }
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=inputs,
            process_data=process_data,
            outputs=outputs,
        )
Callable[[File], int]: + match key: + case "size": + return lambda x: x.size + case _: + raise ValueError(f"Invalid key: {key}") + + +def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: + match key: + case "name": + return lambda x: x.filename or "" + case "type": + return lambda x: x.type + case "extension": + return lambda x: x.extension or "" + case "mimetype": + return lambda x: x.mime_type or "" + case "transfer_method": + return lambda x: x.transfer_method + case "urL": + return lambda x: x.remote_url or "" + case _: + raise ValueError(f"Invalid key: {key}") + + +def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]: + match condition: + case "contains": + return _contains(value) + case "startswith": + return _startswith(value) + case "endswith": + return _endswith(value) + case "is": + return _is(value) + case "in": + return _in(value) + case "empty": + return lambda x: x == "" + case "not contains": + return lambda x: not _contains(value)(x) + case "not is": + return lambda x: not _is(value)(x) + case "not in": + return lambda x: not _in(value)(x) + case "not empty": + return lambda x: x != "" + case _: + raise ValueError(f"Invalid condition: {condition}") + + +def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]: + match condition: + case "in": + return _in(value) + case "not in": + return lambda x: not _in(value)(x) + case _: + raise ValueError(f"Invalid condition: {condition}") + + +def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]: + match condition: + case "=": + return _eq(value) + case "!=": + return _ne(value) + case "<": + return _lt(value) + case "≤": + return _le(value) + case ">": + return _gt(value) + case "≥": + return _ge(value) + case _: + raise ValueError(f"Invalid condition: {condition}") + + +def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], 
bool]: + if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str): + extract_func = _get_file_extract_string_func(key=key) + return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) + if key in {"type", "transfer_method"} and isinstance(value, Sequence): + extract_func = _get_file_extract_string_func(key=key) + return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x)) + elif key == "size" and isinstance(value, str): + extract_func = _get_file_extract_number_func(key=key) + return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x)) + else: + raise ValueError(f"Invalid key: {key}") + + +def _contains(value: str): + return lambda x: value in x + + +def _startswith(value: str): + return lambda x: x.startswith(value) + + +def _endswith(value: str): + return lambda x: x.endswith(value) + + +def _is(value: str): + return lambda x: x is value + + +def _in(value: str | Sequence[str]): + return lambda x: x in value + + +def _eq(value: int | float): + return lambda x: x == value + + +def _ne(value: int | float): + return lambda x: x != value + + +def _lt(value: int | float): + return lambda x: x < value + + +def _le(value: int | float): + return lambda x: x <= value + + +def _gt(value: int | float): + return lambda x: x > value + + +def _ge(value: int | float): + return lambda x: x >= value + + +def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]): + return sorted(array, key=lambda x: x, reverse=order == "desc") + + +def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]): + return sorted(array, key=lambda x: x, reverse=order == "desc") + + +def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]): + if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "urL"}: + extract_func = _get_file_extract_string_func(key=order_by) + return 
sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") + elif order_by == "size": + extract_func = _get_file_extract_number_func(key=order_by) + return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") + else: + raise ValueError(f"Invalid order key: {order_by}") diff --git a/api/core/workflow/nodes/llm/__init__.py b/api/core/workflow/nodes/llm/__init__.py index e69de29bb2..f7bc713f63 100644 --- a/api/core/workflow/nodes/llm/__init__.py +++ b/api/core/workflow/nodes/llm/__init__.py @@ -0,0 +1,17 @@ +from .entities import ( + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, + LLMNodeData, + ModelConfig, + VisionConfig, +) +from .node import LLMNode + +__all__ = [ + "LLMNode", + "LLMNodeChatModelMessage", + "LLMNodeCompletionModelPromptTemplate", + "LLMNodeData", + "ModelConfig", + "VisionConfig", +] diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 93ee0ac250..b4de312461 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,17 +1,15 @@ -from typing import Any, Literal, Optional, Union +from collections.abc import Sequence +from typing import Any, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field +from core.model_runtime.entities import ImagePromptMessageContent from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNodeData class ModelConfig(BaseModel): - """ - Model Config. - """ - provider: str name: str mode: str @@ -19,62 +17,36 @@ class ModelConfig(BaseModel): class ContextConfig(BaseModel): - """ - Context Config. 
- """ - enabled: bool variable_selector: Optional[list[str]] = None +class VisionConfigOptions(BaseModel): + variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"]) + detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH + + class VisionConfig(BaseModel): - """ - Vision Config. - """ - - class Configs(BaseModel): - """ - Configs. - """ - - detail: Literal["low", "high"] - - enabled: bool - configs: Optional[Configs] = None + enabled: bool = False + configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions) class PromptConfig(BaseModel): - """ - Prompt Config. - """ - jinja2_variables: Optional[list[VariableSelector]] = None class LLMNodeChatModelMessage(ChatModelMessage): - """ - LLM Node Chat Model Message. - """ - jinja2_text: Optional[str] = None class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): - """ - LLM Node Chat Model Prompt Template. - """ - jinja2_text: Optional[str] = None class LLMNodeData(BaseNodeData): - """ - LLM Node Data. 
- """ - model: ModelConfig - prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate] + prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate prompt_config: Optional[PromptConfig] = None memory: Optional[MemoryConfig] = None context: ContextConfig - vision: VisionConfig + vision: VisionConfig = Field(default_factory=VisionConfig) diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/node.py similarity index 76% rename from api/core/workflow/nodes/llm/llm_node.py rename to api/core/workflow/nodes/llm/node.py index 3d336b0b0b..24e479153e 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1,39 +1,40 @@ import json from collections.abc import Generator, Mapping, Sequence -from copy import deepcopy from typing import TYPE_CHECKING, Any, Optional, cast -from pydantic import BaseModel - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( + AudioPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, + TextPromptMessageContent, ) +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform 
import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool +from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment +from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import InNodeEvent -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent -from core.workflow.nodes.llm.entities import ( - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - LLMNodeData, - ModelConfig, +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import ( + ModelInvokeCompletedEvent, + NodeEvent, + RunCompletedEvent, + RunRetrieverResourceEvent, + RunStreamChunkEvent, ) from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db @@ -41,44 +42,34 @@ from models.model import Conversation from models.provider import Provider, ProviderType from models.workflow import WorkflowNodeExecutionStatus +from .entities import ( + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, + LLMNodeData, + ModelConfig, +) + if TYPE_CHECKING: - from core.file.file_obj import FileVar + from core.file.models import File -class ModelInvokeCompleted(BaseModel): - """ - Model invoke completed - """ - - text: str - usage: LLMUsage - finish_reason: Optional[str] = None - - -class LLMNode(BaseNode): +class LLMNode(BaseNode[LLMNodeData]): _node_data_cls = 
LLMNodeData _node_type = NodeType.LLM - def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: - """ - Run node - :return: - """ - node_data = cast(LLMNodeData, deepcopy(self.node_data)) - variable_pool = self.graph_runtime_state.variable_pool - + def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]: node_inputs = None process_data = None try: # init messages template - node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template) + self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) # fetch variables and fetch values from variable pool - inputs = self._fetch_inputs(node_data, variable_pool) + inputs = self._fetch_inputs(node_data=self.node_data) # fetch jinja2 inputs - jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool) + jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) # merge inputs inputs.update(jinja_inputs) @@ -86,13 +77,17 @@ class LLMNode(BaseNode): node_inputs = {} # fetch files - files = self._fetch_files(node_data, variable_pool) + files = ( + self._fetch_files(selector=self.node_data.vision.configs.variable_selector) + if self.node_data.vision.enabled + else [] + ) if files: node_inputs["#files#"] = [file.to_dict() for file in files] # fetch context value - generator = self._fetch_context(node_data, variable_pool) + generator = self._fetch_context(node_data=self.node_data) context = None for event in generator: if isinstance(event, RunRetrieverResourceEvent): @@ -103,21 +98,30 @@ class LLMNode(BaseNode): node_inputs["#context#"] = context # type: ignore # fetch model config - model_instance, model_config = self._fetch_model_config(node_data.model) + model_instance, model_config = self._fetch_model_config(self.node_data.model) # fetch memory - memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) + memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance) # 
fetch prompt messages + if self.node_data.memory: + query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) + if not query: + raise ValueError("Query not found") + query = query.text + else: + query = None + prompt_messages, stop = self._fetch_prompt_messages( - node_data=node_data, - query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None, - query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, + system_query=query, inputs=inputs, files=files, context=context, memory=memory, model_config=model_config, + vision_detail=self.node_data.vision.configs.detail, + prompt_template=self.node_data.prompt_template, + memory_config=self.node_data.memory, ) process_data = { @@ -131,7 +135,7 @@ class LLMNode(BaseNode): # handle invoke result generator = self._invoke_llm( - node_data_model=node_data.model, + node_data_model=self.node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, @@ -143,7 +147,7 @@ class LLMNode(BaseNode): for event in generator: if isinstance(event, RunStreamChunkEvent): yield event - elif isinstance(event, ModelInvokeCompleted): + elif isinstance(event, ModelInvokeCompletedEvent): result_text = event.text usage = event.usage finish_reason = event.finish_reason @@ -182,15 +186,7 @@ class LLMNode(BaseNode): model_instance: ModelInstance, prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None, - ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: - """ - Invoke large language model - :param node_data_model: node data model - :param model_instance: model instance - :param prompt_messages: prompt messages - :param stop: stop - :return: - """ + ) -> Generator[NodeEvent, None, None]: db.session.close() invoke_result = model_instance.invoke_llm( @@ -207,20 +203,13 @@ class LLMNode(BaseNode): usage = LLMUsage.empty_usage() for event in generator: yield event - if isinstance(event, 
ModelInvokeCompleted): + if isinstance(event, ModelInvokeCompletedEvent): usage = event.usage # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - def _handle_invoke_result( - self, invoke_result: LLMResult | Generator - ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: - """ - Handle invoke result - :param invoke_result: invoke result - :return: - """ + def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]: if isinstance(invoke_result, LLMResult): return @@ -250,18 +239,11 @@ class LLMNode(BaseNode): if not usage: usage = LLMUsage.empty_usage() - yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason) + yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason) def _transform_chat_messages( - self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: - """ - Transform chat messages - - :param messages: chat messages - :return: - """ - + self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, / + ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: if isinstance(messages, LLMNodeCompletionModelPromptTemplate): if messages.edition_type == "jinja2" and messages.jinja2_text: messages.text = messages.jinja2_text @@ -274,13 +256,7 @@ class LLMNode(BaseNode): return messages - def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: - """ - Fetch jinja inputs - :param node_data: node data - :param variable_pool: variable pool - :return: - """ + def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]: variables = {} if not node_data.prompt_config: @@ -288,7 +264,7 @@ class LLMNode(BaseNode): for variable_selector in node_data.prompt_config.jinja2_variables or []: variable = 
variable_selector.variable - value = variable_pool.get_any(variable_selector.value_selector) + value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) def parse_dict(d: dict) -> str: """ @@ -330,13 +306,7 @@ class LLMNode(BaseNode): return variables - def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: - """ - Fetch inputs - :param node_data: node data - :param variable_pool: variable pool - :return: - """ + def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]: inputs = {} prompt_template = node_data.prompt_template @@ -350,7 +320,7 @@ class LLMNode(BaseNode): variable_selectors = variable_template_parser.extract_variable_selectors() for variable_selector in variable_selectors: - variable_value = variable_pool.get_any(variable_selector.value_selector) + variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) if variable_value is None: raise ValueError(f"Variable {variable_selector.variable} not found") @@ -362,7 +332,7 @@ class LLMNode(BaseNode): template=memory.query_prompt_template ).extract_variable_selectors() for variable_selector in query_variable_selectors: - variable_value = variable_pool.get_any(variable_selector.value_selector) + variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) if variable_value is None: raise ValueError(f"Variable {variable_selector.variable} not found") @@ -370,36 +340,28 @@ class LLMNode(BaseNode): return inputs - def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]: - """ - Fetch files - :param node_data: node data - :param variable_pool: variable pool - :return: - """ - if not node_data.vision.enabled: + def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]: + variable = self.graph_runtime_state.variable_pool.get(selector) + if variable is None: return [] - - files = variable_pool.get_any(["sys", 
SystemVariableKey.FILES.value]) - if not files: + if isinstance(variable, FileSegment): + return [variable.value] + if isinstance(variable, ArrayFileSegment): + return variable.value + # FIXME: Temporary fix for empty array, + # all variables added to variable pool should be a Segment instance. + if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0: return [] + raise ValueError(f"Invalid variable type: {type(variable)}") - return files - - def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]: - """ - Fetch context - :param node_data: node data - :param variable_pool: variable pool - :return: - """ + def _fetch_context(self, node_data: LLMNodeData): if not node_data.context.enabled: return if not node_data.context.variable_selector: return - context_value = variable_pool.get_any(node_data.context.variable_selector) + context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector) if context_value: if isinstance(context_value, str): yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value) @@ -424,11 +386,6 @@ class LLMNode(BaseNode): ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: - """ - Convert to original retriever resource, temp. 
- :param context_dict: context dict - :return: - """ if ( "metadata" in context_dict and "_source" in context_dict["metadata"] @@ -451,6 +408,7 @@ class LLMNode(BaseNode): "segment_position": metadata.get("segment_position"), "index_node_hash": metadata.get("segment_index_node_hash"), "content": context_dict.get("content"), + "page": metadata.get("page"), } return source @@ -460,11 +418,6 @@ class LLMNode(BaseNode): def _fetch_model_config( self, node_data_model: ModelConfig ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - """ - Fetch model config - :param node_data_model: node data model - :return: - """ model_name = node_data_model.name provider_name = node_data_model.provider @@ -523,19 +476,15 @@ class LLMNode(BaseNode): ) def _fetch_memory( - self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance + self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance ) -> Optional[TokenBufferMemory]: - """ - Fetch memory - :param node_data_memory: node data memory - :param variable_pool: variable pool - :return: - """ if not node_data_memory: return None # get conversation id - conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value]) + conversation_id = self.graph_runtime_state.variable_pool.get_any( + ["sys", SystemVariableKey.CONVERSATION_ID.value] + ) if conversation_id is None: return None @@ -555,43 +504,31 @@ class LLMNode(BaseNode): def _fetch_prompt_messages( self, - node_data: LLMNodeData, - query: Optional[str], - query_prompt_template: Optional[str], - inputs: dict[str, str], - files: list["FileVar"], - context: Optional[str], - memory: Optional[TokenBufferMemory], + *, + system_query: str | None = None, + inputs: dict[str, str] | None = None, + files: Sequence["File"], + context: str | None = None, + memory: TokenBufferMemory | None = None, model_config: ModelConfigWithCredentialsEntity, + prompt_template: Sequence[LLMNodeChatModelMessage] | 
LLMNodeCompletionModelPromptTemplate, + memory_config: MemoryConfig | None = None, + vision_detail: ImagePromptMessageContent.DETAIL, ) -> tuple[list[PromptMessage], Optional[list[str]]]: - """ - Fetch prompt messages - :param node_data: node data - :param query: query - :param query_prompt_template: query prompt template - :param inputs: inputs - :param files: files - :param context: context - :param memory: memory - :param model_config: model config - :return: - """ + inputs = inputs or {} + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_messages = prompt_transform.get_prompt( - prompt_template=node_data.prompt_template, + prompt_template=prompt_template, inputs=inputs, - query=query or "", + query=system_query or "", files=files, context=context, - memory_config=node_data.memory, + memory_config=memory_config, memory=memory, model_config=model_config, - query_prompt_template=query_prompt_template, ) stop = model_config.stop - - vision_enabled = node_data.vision.enabled - vision_detail = node_data.vision.configs.detail if node_data.vision.configs else None filtered_prompt_messages = [] for prompt_message in prompt_messages: if prompt_message.is_empty(): @@ -599,17 +536,13 @@ class LLMNode(BaseNode): if not isinstance(prompt_message.content, str): prompt_message_content = [] - for content_item in prompt_message.content: - if ( - vision_enabled - and content_item.type == PromptMessageContentType.IMAGE - and isinstance(content_item, ImagePromptMessageContent) - ): - # Override vision config if LLM node has vision config - if vision_detail: - content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail) + for content_item in prompt_message.content or []: + if isinstance(content_item, ImagePromptMessageContent): + # Override vision config if LLM node has vision config, + # cuz vision detail is related to the configuration from FileUpload feature. 
+ content_item.detail = vision_detail prompt_message_content.append(content_item) - elif content_item.type == PromptMessageContentType.TEXT: + elif isinstance(content_item, TextPromptMessageContent | AudioPromptMessageContent): prompt_message_content.append(content_item) if len(prompt_message_content) > 1: @@ -631,13 +564,6 @@ class LLMNode(BaseNode): @classmethod def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: - """ - Deduct LLM quota - :param tenant_id: tenant id - :param model_instance: model instance - :param usage: usage - :return: - """ provider_model_bundle = model_instance.provider_model_bundle provider_configuration = provider_model_bundle.configuration @@ -668,7 +594,7 @@ class LLMNode(BaseNode): else: used_quota = 1 - if used_quota is not None: + if used_quota is not None and system_configuration.current_quota_type is not None: db.session.query(Provider).filter( Provider.tenant_id == tenant_id, Provider.provider_name == model_instance.provider, @@ -680,27 +606,28 @@ class LLMNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: LLMNodeData, ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ prompt_template = node_data.prompt_template variable_selectors = [] - if isinstance(prompt_template, list): + if isinstance(prompt_template, list) and all( + isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template + ): for prompt in prompt_template: if prompt.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt.text) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - else: + elif isinstance(prompt_template, 
LLMNodeCompletionModelPromptTemplate): if prompt_template.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt_template.text) variable_selectors = variable_template_parser.extract_variable_selectors() + else: + raise ValueError(f"Invalid prompt template type: {type(prompt_template)}") variable_mapping = {} for variable_selector in variable_selectors: @@ -745,11 +672,6 @@ class LLMNode(BaseNode): @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: - """ - Get default config of node. - :param filters: filter by node config parameters. - :return: - """ return { "type": "llm", "config": { diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index a8a0debe64..b7cd7a948e 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,4 +1,4 @@ -from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState +from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState class LoopNodeData(BaseIterationNodeData): diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index fbc68b79cb..6fdff96602 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -1,12 +1,12 @@ from typing import Any -from core.workflow.entities.node_entities import NodeType -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.loop.entities import LoopNodeData, LoopState from core.workflow.utils.condition.entities import Condition -class LoopNode(BaseNode): +class LoopNode(BaseNode[LoopNodeData]): """ Loop Node. 
""" diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index b98525e86e..c13b5ff76f 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -1,22 +1,24 @@ -from core.workflow.entities.node_entities import NodeType -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.http_request.http_request_node import HttpRequestNode -from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.nodes.iteration.iteration_node import IterationNode -from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from core.workflow.nodes.llm.llm_node import LLMNode -from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from core.workflow.nodes.tool.tool_node import ToolNode -from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode +from core.workflow.nodes.answer import AnswerNode +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.code import CodeNode +from core.workflow.nodes.document_extractor import DocumentExtractorNode +from core.workflow.nodes.end import EndNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.http_request import HttpRequestNode +from core.workflow.nodes.if_else import IfElseNode +from core.workflow.nodes.iteration import IterationNode, IterationStartNode +from core.workflow.nodes.knowledge_retrieval 
import KnowledgeRetrievalNode +from core.workflow.nodes.list_operator import ListOperatorNode +from core.workflow.nodes.llm import LLMNode +from core.workflow.nodes.parameter_extractor import ParameterExtractorNode +from core.workflow.nodes.question_classifier import QuestionClassifierNode +from core.workflow.nodes.start import StartNode +from core.workflow.nodes.template_transform import TemplateTransformNode +from core.workflow.nodes.tool import ToolNode +from core.workflow.nodes.variable_aggregator import VariableAggregatorNode from core.workflow.nodes.variable_assigner import VariableAssignerNode -node_classes = { +node_type_classes_mapping: dict[NodeType, type[BaseNode]] = { NodeType.START: StartNode, NodeType.END: EndNode, NodeType.ANSWER: AnswerNode, @@ -34,4 +36,6 @@ node_classes = { NodeType.ITERATION_START: IterationStartNode, NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, + NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode, + NodeType.LIST_OPERATOR: ListOperatorNode, } diff --git a/api/core/workflow/nodes/parameter_extractor/__init__.py b/api/core/workflow/nodes/parameter_extractor/__init__.py index e69de29bb2..bdbf19a7d3 100644 --- a/api/core/workflow/nodes/parameter_extractor/__init__.py +++ b/api/core/workflow/nodes/parameter_extractor/__init__.py @@ -0,0 +1,3 @@ +from .parameter_extractor_node import ParameterExtractorNode + +__all__ = ["ParameterExtractorNode"] diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 5697d7c049..a001b44dc7 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -1,20 +1,10 @@ from typing import Any, Literal, Optional -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, Field, field_validator from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from 
core.workflow.entities.base_node_data_entities import BaseNodeData - - -class ModelConfig(BaseModel): - """ - Model Config. - """ - - provider: str - name: str - mode: str - completion_params: dict[str, Any] = {} +from core.workflow.nodes.base import BaseNodeData +from core.workflow.nodes.llm import ModelConfig, VisionConfig class ParameterConfig(BaseModel): @@ -49,6 +39,7 @@ class ParameterExtractorNodeData(BaseNodeData): instruction: Optional[str] = None memory: Optional[MemoryConfig] = None reasoning_mode: Literal["function_call", "prompt"] + vision: VisionConfig = Field(default_factory=VisionConfig) @field_validator("reasoning_mode", mode="before") @classmethod @@ -64,7 +55,7 @@ class ParameterExtractorNodeData(BaseNodeData): parameters = {"type": "object", "properties": {}, "required": []} for parameter in self.parameters: - parameter_schema = {"description": parameter.description} + parameter_schema: dict[str, Any] = {"description": parameter.description} if parameter.type in {"string", "select"}: parameter_schema["type"] = "string" diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index a6454bd1cd..49546e9356 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.file import File from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage @@ -22,12 +23,16 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate 
from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.llm.entities import ModelConfig -from core.workflow.nodes.llm.llm_node import LLMNode -from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from core.workflow.nodes.parameter_extractor.prompts import ( +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.llm import LLMNode, ModelConfig +from core.workflow.utils import variable_template_parser +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import ParameterExtractorNodeData +from .prompts import ( CHAT_EXAMPLE, CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, COMPLETION_GENERATE_JSON_PROMPT, @@ -36,9 +41,6 @@ from core.workflow.nodes.parameter_extractor.prompts import ( FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT, FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE, ) -from core.workflow.utils.variable_template_parser import VariableTemplateParser -from extensions.ext_database import db -from models.workflow import WorkflowNodeExecutionStatus class ParameterExtractorNode(LLMNode): @@ -65,33 +67,39 @@ class ParameterExtractorNode(LLMNode): } } - def _run(self) -> NodeRunResult: + def _run(self): """ Run the node. 
""" node_data = cast(ParameterExtractorNodeData, self.node_data) - variable = self.graph_runtime_state.variable_pool.get_any(node_data.query) - if not variable: - raise ValueError("Input variable content not found or is empty") - query = variable + variable = self.graph_runtime_state.variable_pool.get(node_data.query) + query = variable.text if variable else "" - inputs = { - "query": query, - "parameters": jsonable_encoder(node_data.parameters), - "instruction": jsonable_encoder(node_data.instruction), - } + files = ( + self._fetch_files( + selector=node_data.vision.configs.variable_selector, + ) + if node_data.vision.enabled + else [] + ) model_instance, model_config = self._fetch_model_config(node_data.model) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): raise ValueError("Model is not a Large Language Model") llm_model = model_instance.model_type_instance - model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials) + model_schema = llm_model.get_model_schema( + model=model_config.model, + credentials=model_config.credentials, + ) if not model_schema: raise ValueError("Model schema not found") # fetch memory - memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance) + memory = self._fetch_memory( + node_data_memory=node_data.memory, + model_instance=model_instance, + ) if ( set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} @@ -99,15 +107,33 @@ class ParameterExtractorNode(LLMNode): ): # use function call prompt_messages, prompt_message_tools = self._generate_function_call_prompt( - node_data, query, self.graph_runtime_state.variable_pool, model_config, memory + node_data=node_data, + query=query, + variable_pool=self.graph_runtime_state.variable_pool, + model_config=model_config, + memory=memory, + files=files, ) else: # use prompt engineering prompt_messages = self._generate_prompt_engineering_prompt( - node_data, query, 
self.graph_runtime_state.variable_pool, model_config, memory + data=node_data, + query=query, + variable_pool=self.graph_runtime_state.variable_pool, + model_config=model_config, + memory=memory, + files=files, ) + prompt_message_tools = [] + inputs = { + "query": query, + "files": [f.to_dict() for f in files], + "parameters": jsonable_encoder(node_data.parameters), + "instruction": jsonable_encoder(node_data.instruction), + } + process_data = { "model_mode": model_config.mode, "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( @@ -119,7 +145,7 @@ class ParameterExtractorNode(LLMNode): } try: - text, usage, tool_call = self._invoke_llm( + text, usage, tool_call = self._invoke( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, @@ -150,12 +176,12 @@ class ParameterExtractorNode(LLMNode): error = "Failed to extract result from function call or text response, using empty result." try: - result = self._validate_result(node_data, result) + result = self._validate_result(data=node_data, result=result or {}) except Exception as e: error = str(e) # transform result into standard format - result = self._transform_result(node_data, result) + result = self._transform_result(data=node_data, result=result or {}) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -170,7 +196,7 @@ class ParameterExtractorNode(LLMNode): llm_usage=usage, ) - def _invoke_llm( + def _invoke( self, node_data_model: ModelConfig, model_instance: ModelInstance, @@ -178,14 +204,6 @@ class ParameterExtractorNode(LLMNode): tools: list[PromptMessageTool], stop: list[str], ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: - """ - Invoke large language model - :param node_data_model: node data model - :param model_instance: model instance - :param prompt_messages: prompt messages - :param stop: stop - :return: - """ db.session.close() invoke_result = model_instance.invoke_llm( @@ -202,6 +220,9 @@ class 
ParameterExtractorNode(LLMNode): raise ValueError(f"Invalid invoke result: {invoke_result}") text = invoke_result.message.content + if not isinstance(text, str): + raise ValueError(f"Invalid text content type: {type(text)}. Expected str.") + usage = invoke_result.usage tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None @@ -217,6 +238,7 @@ class ParameterExtractorNode(LLMNode): variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, memory: Optional[TokenBufferMemory], + files: Sequence[File], ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: """ Generate function call prompt. @@ -234,7 +256,7 @@ class ParameterExtractorNode(LLMNode): prompt_template=prompt_template, inputs={}, query="", - files=[], + files=files, context="", memory_config=node_data.memory, memory=None, @@ -296,6 +318,7 @@ class ParameterExtractorNode(LLMNode): variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, memory: Optional[TokenBufferMemory], + files: Sequence[File], ) -> list[PromptMessage]: """ Generate prompt engineering prompt. 
@@ -303,9 +326,23 @@ class ParameterExtractorNode(LLMNode): model_mode = ModelMode.value_of(data.model.mode) if model_mode == ModelMode.COMPLETION: - return self._generate_prompt_engineering_completion_prompt(data, query, variable_pool, model_config, memory) + return self._generate_prompt_engineering_completion_prompt( + node_data=data, + query=query, + variable_pool=variable_pool, + model_config=model_config, + memory=memory, + files=files, + ) elif model_mode == ModelMode.CHAT: - return self._generate_prompt_engineering_chat_prompt(data, query, variable_pool, model_config, memory) + return self._generate_prompt_engineering_chat_prompt( + node_data=data, + query=query, + variable_pool=variable_pool, + model_config=model_config, + memory=memory, + files=files, + ) else: raise ValueError(f"Invalid model mode: {model_mode}") @@ -316,20 +353,23 @@ class ParameterExtractorNode(LLMNode): variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, memory: Optional[TokenBufferMemory], + files: Sequence[File], ) -> list[PromptMessage]: """ Generate completion prompt. 
""" prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + rest_token = self._calculate_rest_token( + node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context="" + ) prompt_template = self._get_prompt_engineering_prompt_template( - node_data, query, variable_pool, memory, rest_token + node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={"structure": json.dumps(node_data.get_parameter_json_schema())}, query="", - files=[], + files=files, context="", memory_config=node_data.memory, memory=memory, @@ -345,27 +385,30 @@ class ParameterExtractorNode(LLMNode): variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, memory: Optional[TokenBufferMemory], + files: Sequence[File], ) -> list[PromptMessage]: """ Generate chat prompt. 
""" prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + rest_token = self._calculate_rest_token( + node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context="" + ) prompt_template = self._get_prompt_engineering_prompt_template( - node_data, - CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( + node_data=node_data, + query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( structure=json.dumps(node_data.get_parameter_json_schema()), text=query ), - variable_pool, - memory, - rest_token, + variable_pool=variable_pool, + memory=memory, + max_token_limit=rest_token, ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, query="", - files=[], + files=files, context="", memory_config=node_data.memory, memory=None, @@ -425,10 +468,11 @@ class ParameterExtractorNode(LLMNode): raise ValueError(f"Invalid `string` value for parameter {parameter.name}") if parameter.type.startswith("array"): - if not isinstance(result.get(parameter.name), list): + parameters = result.get(parameter.name) + if not isinstance(parameters, list): raise ValueError(f"Invalid `array` value for parameter {parameter.name}") nested_type = parameter.type[6:-1] - for item in result.get(parameter.name): + for item in parameters: if nested_type == "number" and not isinstance(item, int | float): raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}") if nested_type == "string" and not isinstance(item, str): @@ -565,18 +609,6 @@ class ParameterExtractorNode(LLMNode): return result - def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> str: - """ - Render instruction. 
- """ - variable_template_parser = VariableTemplateParser(instruction) - inputs = {} - for selector in variable_template_parser.extract_variable_selectors(): - variable = variable_pool.get_any(selector.value_selector) - inputs[selector.variable] = variable - - return variable_template_parser.format(inputs) - def _get_function_calling_prompt_template( self, node_data: ParameterExtractorNodeData, @@ -588,9 +620,9 @@ class ParameterExtractorNode(LLMNode): model_mode = ModelMode.value_of(node_data.model.mode) input_text = query memory_str = "" - instruction = self._render_instruction(node_data.instruction or "", variable_pool) + instruction = variable_pool.convert_template(node_data.instruction or "").text - if memory: + if memory and node_data.memory and node_data.memory.window: memory_str = memory.get_history_prompt_text( max_token_limit=max_token_limit, message_limit=node_data.memory.window.size ) @@ -611,13 +643,13 @@ class ParameterExtractorNode(LLMNode): variable_pool: VariablePool, memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, - ) -> list[ChatModelMessage]: + ): model_mode = ModelMode.value_of(node_data.model.mode) input_text = query memory_str = "" - instruction = self._render_instruction(node_data.instruction or "", variable_pool) + instruction = variable_pool.convert_template(node_data.instruction or "").text - if memory: + if memory and node_data.memory and node_data.memory.window: memory_str = memory.get_history_prompt_text( max_token_limit=max_token_limit, message_limit=node_data.memory.window.size ) @@ -691,7 +723,7 @@ class ParameterExtractorNode(LLMNode): ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens @@ -712,7 +744,11 @@ class ParameterExtractorNode(LLMNode): @classmethod def 
_extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: ParameterExtractorNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: ParameterExtractorNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -721,11 +757,11 @@ class ParameterExtractorNode(LLMNode): :param node_data: node data :return: """ - variable_mapping = {"query": node_data.query} + variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} if node_data.instruction: - variable_template_parser = VariableTemplateParser(template=node_data.instruction) - for selector in variable_template_parser.extract_variable_selectors(): + selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) + for selector in selectors: variable_mapping[selector.variable] = selector.value_selector variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} diff --git a/api/core/workflow/nodes/question_classifier/__init__.py b/api/core/workflow/nodes/question_classifier/__init__.py index e69de29bb2..70414c4199 100644 --- a/api/core/workflow/nodes/question_classifier/__init__.py +++ b/api/core/workflow/nodes/question_classifier/__init__.py @@ -0,0 +1,4 @@ +from .entities import QuestionClassifierNodeData +from .question_classifier_node import QuestionClassifierNode + +__all__ = ["QuestionClassifierNodeData", "QuestionClassifierNode"] diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py index 40f7ce7582..5219f11d26 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -1,39 +1,21 @@ -from typing import Any, Optional +from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.prompt.entities.advanced_prompt_entities import 
MemoryConfig -from core.workflow.entities.base_node_data_entities import BaseNodeData - - -class ModelConfig(BaseModel): - """ - Model Config. - """ - - provider: str - name: str - mode: str - completion_params: dict[str, Any] = {} +from core.workflow.nodes.base import BaseNodeData +from core.workflow.nodes.llm import ModelConfig, VisionConfig class ClassConfig(BaseModel): - """ - Class Config. - """ - id: str name: str class QuestionClassifierNodeData(BaseNodeData): - """ - Knowledge retrieval Node Data. - """ - query_variable_selector: list[str] - type: str = "question-classifier" model: ModelConfig classes: list[ClassConfig] instruction: Optional[str] = None memory: Optional[MemoryConfig] = None + vision: VisionConfig = Field(default_factory=VisionConfig) diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 2ae58bc5f7..e6af453dcf 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,25 +1,30 @@ import json import logging from collections.abc import Mapping, Sequence -from typing import Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole -from core.model_runtime.entities.model_entities import ModelPropertyKey +from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from 
core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted -from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData -from core.workflow.nodes.question_classifier.template_prompts import ( +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import ModelInvokeCompletedEvent +from core.workflow.nodes.llm import ( + LLMNode, + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, +) +from core.workflow.utils.variable_template_parser import VariableTemplateParser +from libs.json_in_md_parser import parse_and_check_json_markdown +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import QuestionClassifierNodeData +from .template_prompts import ( QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2, QUESTION_CLASSIFIER_COMPLETION_PROMPT, @@ -28,46 +33,77 @@ from core.workflow.nodes.question_classifier.template_prompts import ( QUESTION_CLASSIFIER_USER_PROMPT_2, QUESTION_CLASSIFIER_USER_PROMPT_3, ) -from core.workflow.utils.variable_template_parser import VariableTemplateParser -from libs.json_in_md_parser import parse_and_check_json_markdown -from models.workflow import WorkflowNodeExecutionStatus + +if TYPE_CHECKING: + from core.file import File class QuestionClassifierNode(LLMNode): _node_data_cls = QuestionClassifierNodeData - node_type = NodeType.QUESTION_CLASSIFIER + _node_type = 
NodeType.QUESTION_CLASSIFIER - def _run(self) -> NodeRunResult: - node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data) - node_data = cast(QuestionClassifierNodeData, node_data) + def _run(self): + node_data = cast(QuestionClassifierNodeData, self.node_data) variable_pool = self.graph_runtime_state.variable_pool # extract variables - variable = variable_pool.get(node_data.query_variable_selector) + variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None query = variable.value if variable else None variables = {"query": query} # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) # fetch memory - memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) + memory = self._fetch_memory( + node_data_memory=node_data.memory, + model_instance=model_instance, + ) # fetch instruction - instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else "" - node_data.instruction = instruction + node_data.instruction = node_data.instruction or "" + node_data.instruction = variable_pool.convert_template(node_data.instruction).text + + files: Sequence[File] = ( + self._fetch_files( + selector=node_data.vision.configs.variable_selector, + ) + if node_data.vision.enabled + else [] + ) + # fetch prompt messages - prompt_messages, stop = self._fetch_prompt( - node_data=node_data, context="", query=query, memory=memory, model_config=model_config + rest_token = self._calculate_rest_token( + node_data=node_data, + query=query or "", + model_config=model_config, + context="", + ) + prompt_template = self._get_prompt_template( + node_data=node_data, + query=query or "", + memory=memory, + max_token_limit=rest_token, + ) + prompt_messages, stop = self._fetch_prompt_messages( + prompt_template=prompt_template, + system_query=query, + memory=memory, + model_config=model_config, + files=files, + 
vision_detail=node_data.vision.configs.detail, ) # handle invoke result generator = self._invoke_llm( - node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop + node_data_model=node_data.model, + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop, ) result_text = "" usage = LLMUsage.empty_usage() finish_reason = None for event in generator: - if isinstance(event, ModelInvokeCompleted): + if isinstance(event, ModelInvokeCompletedEvent): result_text = event.text usage = event.usage finish_reason = event.finish_reason @@ -129,7 +165,11 @@ class QuestionClassifierNode(LLMNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: QuestionClassifierNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: QuestionClassifierNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -159,40 +199,6 @@ class QuestionClassifierNode(LLMNode): """ return {"type": "question-classifier", "config": {"instructions": ""}} - def _fetch_prompt( - self, - node_data: QuestionClassifierNodeData, - query: str, - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: - """ - Fetch prompt - :param node_data: node data - :param query: inputs - :param context: context - :param memory: memory - :param model_config: model config - :return: - """ - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, model_config, context) - prompt_template = self._get_prompt_template(node_data, query, memory, rest_token) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=[], - context=context, - memory_config=node_data.memory, - memory=None, - 
model_config=model_config, - ) - stop = model_config.stop - - return prompt_messages, stop - def _calculate_rest_token( self, node_data: QuestionClassifierNodeData, @@ -229,7 +235,7 @@ class QuestionClassifierNode(LLMNode): ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens @@ -243,7 +249,7 @@ class QuestionClassifierNode(LLMNode): query: str, memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, - ) -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]: + ): model_mode = ModelMode.value_of(node_data.model.mode) classes = node_data.classes categories = [] @@ -255,31 +261,32 @@ class QuestionClassifierNode(LLMNode): memory_str = "" if memory: memory_str = memory.get_history_prompt_text( - max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + max_token_limit=max_token_limit, + message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None, ) - prompt_messages = [] + prompt_messages: list[LLMNodeChatModelMessage] = [] if model_mode == ModelMode.CHAT: - system_prompt_messages = ChatModelMessage( + system_prompt_messages = LLMNodeChatModelMessage( role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) ) prompt_messages.append(system_prompt_messages) - user_prompt_message_1 = ChatModelMessage( + user_prompt_message_1 = LLMNodeChatModelMessage( role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1 ) prompt_messages.append(user_prompt_message_1) - assistant_prompt_message_1 = ChatModelMessage( + assistant_prompt_message_1 = LLMNodeChatModelMessage( role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 ) prompt_messages.append(assistant_prompt_message_1) - user_prompt_message_2 = 
ChatModelMessage( + user_prompt_message_2 = LLMNodeChatModelMessage( role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2 ) prompt_messages.append(user_prompt_message_2) - assistant_prompt_message_2 = ChatModelMessage( + assistant_prompt_message_2 = LLMNodeChatModelMessage( role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 ) prompt_messages.append(assistant_prompt_message_2) - user_prompt_message_3 = ChatModelMessage( + user_prompt_message_3 = LLMNodeChatModelMessage( role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_3.format( input_text=input_text, @@ -290,7 +297,7 @@ class QuestionClassifierNode(LLMNode): prompt_messages.append(user_prompt_message_3) return prompt_messages elif model_mode == ModelMode.COMPLETION: - return CompletionModelPromptTemplate( + return LLMNodeCompletionModelPromptTemplate( text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( histories=memory_str, input_text=input_text, @@ -302,23 +309,3 @@ class QuestionClassifierNode(LLMNode): else: raise ValueError(f"Model mode {model_mode} not support.") - - def _format_instruction(self, instruction: str, variable_pool: VariablePool) -> str: - inputs = {} - - variable_selectors = [] - variable_template_parser = VariableTemplateParser(template=instruction) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - for variable_selector in variable_selectors: - variable = variable_pool.get(variable_selector.value_selector) - variable_value = variable.value if variable else None - if variable_value is None: - raise ValueError(f"Variable {variable_selector.variable} not found") - - inputs[variable_selector.variable] = variable_value - - prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - instruction = prompt_template.format(prompt_inputs) - return instruction diff --git 
a/api/core/workflow/nodes/start/__init__.py b/api/core/workflow/nodes/start/__init__.py index e69de29bb2..5411780423 100644 --- a/api/core/workflow/nodes/start/__init__.py +++ b/api/core/workflow/nodes/start/__init__.py @@ -0,0 +1,3 @@ +from .start_node import StartNode + +__all__ = ["StartNode"] diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index 11d2ebe5dd..594d1b7bab 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from pydantic import Field from core.app.app_config.entities import VariableEntity -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class StartNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 96c887c58d..a7b91e82bb 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,25 +1,24 @@ from collections.abc import Mapping, Sequence from typing import Any -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID -from core.workflow.nodes.base_node import BaseNode +from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.start.entities import StartNodeData from models.workflow import WorkflowNodeExecutionStatus -class StartNode(BaseNode): +class StartNode(BaseNode[StartNodeData]): _node_data_cls = StartNodeData _node_type = NodeType.START def _run(self) -> NodeRunResult: - """ - Run node - :return: - """ node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) system_inputs = 
self.graph_runtime_state.variable_pool.system_variables + # TODO: System variables should be directly accessible, no need for special handling + # Set system variables as node outputs. for var in system_inputs: node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] @@ -27,13 +26,10 @@ class StartNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: StartNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: StartNodeData, ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ return {} diff --git a/api/core/workflow/nodes/template_transform/__init__.py b/api/core/workflow/nodes/template_transform/__init__.py index e69de29bb2..43863b9d59 100644 --- a/api/core/workflow/nodes/template_transform/__init__.py +++ b/api/core/workflow/nodes/template_transform/__init__.py @@ -0,0 +1,3 @@ +from .template_transform_node import TemplateTransformNode + +__all__ = ["TemplateTransformNode"] diff --git a/api/core/workflow/nodes/template_transform/entities.py b/api/core/workflow/nodes/template_transform/entities.py index e934d69fa3..96adff6ffa 100644 --- a/api/core/workflow/nodes/template_transform/entities.py +++ b/api/core/workflow/nodes/template_transform/entities.py @@ -1,5 +1,5 @@ -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNodeData class TemplateTransformNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 32c99e0d1c..857a693c5b 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ 
b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,17 +1,18 @@ import os from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, Optional from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from models.workflow import WorkflowNodeExecutionStatus MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) -class TemplateTransformNode(BaseNode): +class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): _node_data_cls = TemplateTransformNodeData _node_type = NodeType.TEMPLATE_TRANSFORM @@ -28,22 +29,16 @@ class TemplateTransformNode(BaseNode): } def _run(self) -> NodeRunResult: - """ - Run node - """ - node_data = self.node_data - node_data: TemplateTransformNodeData = cast(self._node_data_cls, node_data) - # Get variables variables = {} - for variable_selector in node_data.variables: + for variable_selector in self.node_data.variables: variable_name = variable_selector.variable value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) variables[variable_name] = value # Run code try: result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables + language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables ) except CodeExecutionError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) diff --git a/api/core/workflow/nodes/tool/__init__.py 
b/api/core/workflow/nodes/tool/__init__.py index e69de29bb2..f4982e655d 100644 --- a/api/core/workflow/nodes/tool/__init__.py +++ b/api/core/workflow/nodes/tool/__init__.py @@ -0,0 +1,3 @@ +from .tool_node import ToolNode + +__all__ = ["ToolNode"] diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 9d222b10b9..9e29791481 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -3,7 +3,7 @@ from typing import Any, Literal, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class ToolEntity(BaseModel): @@ -51,7 +51,4 @@ class ToolNodeData(BaseNodeData, ToolEntity): raise ValueError("value must be a string, int, float, or bool") return typ - """ - Tool Node Schema - """ tool_parameters: dict[str, ToolInput] diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 3b86b29cf8..df22130d69 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,24 +1,28 @@ from collections.abc import Mapping, Sequence from os import path -from typing import Any, cast +from typing import Any + +from sqlalchemy import select +from sqlalchemy.orm import Session -from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.file.models import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer -from 
core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser -from models import WorkflowNodeExecutionStatus +from extensions.ext_database import db +from models import ToolFile +from models.workflow import WorkflowNodeExecutionStatus -class ToolNode(BaseNode): +class ToolNode(BaseNode[ToolNodeData]): """ Tool Node """ @@ -27,37 +31,38 @@ class ToolNode(BaseNode): _node_type = NodeType.TOOL def _run(self) -> NodeRunResult: - """ - Run the tool node - """ - - node_data = cast(ToolNodeData, self.node_data) - # fetch tool icon - tool_info = {"provider_type": node_data.provider_type, "provider_id": node_data.provider_id} + tool_info = { + "provider_type": self.node_data.provider_type, + "provider_id": self.node_data.provider_id, + } # get tool runtime try: tool_runtime = ToolManager.get_workflow_tool_runtime( - self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from + self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from ) except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, - metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + metadata={ + NodeRunMetadataKey.TOOL_INFO: tool_info, + }, error=f"Failed to get tool runtime: {str(e)}", ) # get parameters tool_parameters = tool_runtime.get_runtime_parameters() or [] parameters = self._generate_parameters( - tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data + 
tool_parameters=tool_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=self.node_data, ) parameters_for_log = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=node_data, + node_data=self.node_data, for_log=True, ) @@ -74,7 +79,9 @@ class ToolNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + metadata={ + NodeRunMetadataKey.TOOL_INFO: tool_info, + }, error=f"Failed to invoke tool: {str(e)}", ) @@ -83,8 +90,14 @@ class ToolNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": plain_text, "files": files, "json": json}, - metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + outputs={ + "text": plain_text, + "files": files, + "json": json, + }, + metadata={ + NodeRunMetadataKey.TOOL_INFO: tool_info, + }, inputs=parameters_for_log, ) @@ -116,29 +129,25 @@ class ToolNode(BaseNode): if not parameter: result[parameter_name] = None continue - if parameter.type == ToolParameter.ToolParameterType.FILE: - result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)] + tool_input = node_data.tool_parameters[parameter_name] + if tool_input.type == "variable": + variable = variable_pool.get(tool_input.value) + if variable is None: + raise ValueError(f"variable {tool_input.value} not exists") + parameter_value = variable.value + elif tool_input.type in {"mixed", "constant"}: + segment_group = variable_pool.convert_template(str(tool_input.value)) + parameter_value = segment_group.log if for_log else segment_group.text else: - tool_input = node_data.tool_parameters[parameter_name] - if tool_input.type == "variable": - # TODO: check if the variable exists in the variable pool - parameter_value = variable_pool.get(tool_input.value).value - else: - segment_group = parser.convert_template( - 
template=str(tool_input.value), - variable_pool=variable_pool, - ) - parameter_value = segment_group.log if for_log else segment_group.text - result[parameter_name] = parameter_value + raise ValueError(f"unknown tool input type '{tool_input.type}'") + result[parameter_name] = parameter_value return result - def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: - variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) - assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) - return list(variable.value) if variable else [] - - def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar], list[dict]]: + def _convert_tool_messages( + self, + messages: list[ToolInvokeMessage], + ): """ Convert ToolInvokeMessages into tuple[plain_text, files] """ @@ -156,50 +165,86 @@ class ToolNode(BaseNode): return plain_text, files, json - def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]: + def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[File]: """ Extract tool response binary """ result = [] - for response in tool_response: if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: - url = response.message - ext = path.splitext(url)[1] - mimetype = response.meta.get("mime_type", "image/jpeg") - filename = response.save_as or url.split("/")[-1] + url = str(response.message) if response.message else None + ext = path.splitext(url)[1] if url else ".bin" + tool_file_id = str(url).split("/")[-1].split(".")[0] transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) - # get tool file id - tool_file_id = url.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ValueError(f"tool file {tool_file_id} not exists") + 
result.append( - FileVar( + File( tenant_id=self.tenant_id, type=FileType.IMAGE, transfer_method=transfer_method, - url=url, - related_id=tool_file_id, - filename=filename, + remote_url=url, + related_id=tool_file.id, + filename=tool_file.name, extension=ext, - mime_type=mimetype, + mime_type=tool_file.mimetype, + size=tool_file.size, ) ) elif response.type == ToolInvokeMessage.MessageType.BLOB: # get tool file id - tool_file_id = response.message.split("/")[-1].split(".")[0] + tool_file_id = str(response.message).split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ValueError(f"tool file {tool_file_id} not exists") result.append( - FileVar( + File( tenant_id=self.tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file_id, - filename=response.save_as, + related_id=tool_file.id, + filename=tool_file.name, extension=path.splitext(response.save_as)[1], - mime_type=response.meta.get("mime_type", "application/octet-stream"), + mime_type=tool_file.mimetype, + size=tool_file.size, ) ) elif response.type == ToolInvokeMessage.MessageType.LINK: - pass # TODO: + url = str(response.message) + transfer_method = FileTransferMethod.TOOL_FILE + tool_file_id = url.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ValueError(f"tool file {tool_file_id} not exists") + if "." in url: + extension = "." 
+ url.split("/")[-1].split(".")[1] + else: + extension = ".bin" + file = File( + tenant_id=self.tenant_id, + type=FileType(response.save_as), + transfer_method=transfer_method, + remote_url=url, + filename=tool_file.name, + related_id=tool_file.id, + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + ) + result.append(file) + + elif response.type == ToolInvokeMessage.MessageType.FILE: + assert response.meta is not None + result.append(response.meta["file"]) return result @@ -218,12 +263,16 @@ class ToolNode(BaseNode): ] ) - def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]: + def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]): return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON] @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: ToolNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -236,7 +285,7 @@ class ToolNode(BaseNode): for parameter_name in node_data.tool_parameters: input = node_data.tool_parameters[parameter_name] if input.type == "mixed": - selectors = VariableTemplateParser(input.value).extract_variable_selectors() + selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() for selector in selectors: result[selector.variable] = selector.value_selector elif input.type == "variable": diff --git a/api/core/workflow/nodes/variable_aggregator/__init__.py b/api/core/workflow/nodes/variable_aggregator/__init__.py index e69de29bb2..0b6bf2a5b6 100644 --- a/api/core/workflow/nodes/variable_aggregator/__init__.py +++ b/api/core/workflow/nodes/variable_aggregator/__init__.py @@ -0,0 +1,3 @@ +from .variable_aggregator_node import VariableAggregatorNode + +__all__ = 
["VariableAggregatorNode"] diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py index eb893a04e3..71a930e6b0 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -2,7 +2,7 @@ from typing import Literal, Optional from pydantic import BaseModel -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class AdvancedSettings(BaseModel): diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index f03eae257a..05477e2a90 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,24 +1,24 @@ from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData from models.workflow import WorkflowNodeExecutionStatus -class VariableAggregatorNode(BaseNode): +class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): _node_data_cls = VariableAssignerNodeData _node_type = NodeType.VARIABLE_AGGREGATOR def _run(self) -> NodeRunResult: - node_data = cast(VariableAssignerNodeData, self.node_data) # Get variables outputs = {} inputs = {} - if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled: - for selector in node_data.variables: + if not self.node_data.advanced_settings or not 
self.node_data.advanced_settings.group_enabled: + for selector in self.node_data.variables: variable = self.graph_runtime_state.variable_pool.get_any(selector) if variable is not None: outputs = {"output": variable} @@ -26,7 +26,7 @@ class VariableAggregatorNode(BaseNode): inputs = {".".join(selector[1:]): variable} break else: - for group in node_data.advanced_settings.groups: + for group in self.node_data.advanced_settings.groups: for selector in group.variables: variable = self.graph_runtime_state.variable_pool.get_any(selector) diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/node.py index 3969299795..4e66f640df 100644 --- a/api/core/workflow/nodes/variable_assigner/node.py +++ b/api/core/workflow/nodes/variable_assigner/node.py @@ -1,40 +1,38 @@ -from typing import cast - from sqlalchemy import select from sqlalchemy.orm import Session -from core.app.segments import SegmentType, Variable, factory -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode +from core.variables import SegmentType, Variable +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode, BaseNodeData +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models import ConversationVariable, WorkflowNodeExecutionStatus +from factories import variable_factory +from models import ConversationVariable +from models.workflow import WorkflowNodeExecutionStatus from .exc import VariableAssignerNodeError from .node_data import VariableAssignerData, WriteMode -class VariableAssignerNode(BaseNode): +class VariableAssignerNode(BaseNode[VariableAssignerData]): _node_data_cls: type[BaseNodeData] = VariableAssignerData _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER def _run(self) -> NodeRunResult: - 
data = cast(VariableAssignerData, self.node_data) - # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector) + original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector) if not isinstance(original_variable, Variable): raise VariableAssignerNodeError("assigned variable not found") - match data.write_mode: + match self.node_data.write_mode: case WriteMode.OVER_WRITE: - income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) if not income_value: raise VariableAssignerNodeError("input value not found") updated_variable = original_variable.model_copy(update={"value": income_value.value}) case WriteMode.APPEND: - income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) if not income_value: raise VariableAssignerNodeError("input value not found") updated_value = original_variable.value + [income_value.value] @@ -45,10 +43,10 @@ class VariableAssignerNode(BaseNode): updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: - raise VariableAssignerNodeError(f"unsupported write mode: {data.write_mode}") + raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}") # Over write the variable. - self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable) + self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable) # TODO: Move database operation to the pipeline. # Update conversation variable. 
@@ -80,12 +78,12 @@ def update_conversation_variable(conversation_id: str, variable: Variable): def get_zero_value(t: SegmentType): match t: case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: - return factory.build_segment([]) + return variable_factory.build_segment([]) case SegmentType.OBJECT: - return factory.build_segment({}) + return variable_factory.build_segment({}) case SegmentType.STRING: - return factory.build_segment("") + return variable_factory.build_segment("") case SegmentType.NUMBER: - return factory.build_segment(0) + return variable_factory.build_segment(0) case _: raise VariableAssignerNodeError(f"unsupported variable type: {t}") diff --git a/api/core/workflow/nodes/variable_assigner/node_data.py b/api/core/workflow/nodes/variable_assigner/node_data.py index 8ac8eadf7c..70ae29d45f 100644 --- a/api/core/workflow/nodes/variable_assigner/node_data.py +++ b/api/core/workflow/nodes/variable_assigner/node_data.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from enum import Enum from typing import Optional -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class WriteMode(str, Enum): diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py index b8e8b881a5..1d96743879 100644 --- a/api/core/workflow/utils/condition/entities.py +++ b/api/core/workflow/utils/condition/entities.py @@ -1,32 +1,46 @@ -from typing import Literal, Optional +from collections.abc import Sequence +from typing import Literal -from pydantic import BaseModel +from pydantic import BaseModel, Field + +SupportedComparisonOperator = Literal[ + # for string or array + "contains", + "not contains", + "start with", + "end with", + "is", + "is not", + "empty", + "not empty", + "in", + "not in", + "all of", + # for number + "=", + "≠", + ">", + "<", + "≥", + "≤", + "null", + "not null", +] + + +class SubCondition(BaseModel): + 
key: str + comparison_operator: SupportedComparisonOperator + value: str | Sequence[str] | None = None + + +class SubVariableCondition(BaseModel): + logical_operator: Literal["and", "or"] + conditions: list[SubCondition] = Field(default=list) class Condition(BaseModel): - """ - Condition entity - """ - variable_selector: list[str] - comparison_operator: Literal[ - # for string or array - "contains", - "not contains", - "start with", - "end with", - "is", - "is not", - "empty", - "not empty", - # for number - "=", - "≠", - ">", - "<", - "≥", - "≤", - "null", - "not null", - ] - value: Optional[str] = None + comparison_operator: SupportedComparisonOperator + value: str | Sequence[str] | None = None + sub_variable_condition: SubVariableCondition | None = None diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py index 395ee82478..f4a80fa5e1 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/core/workflow/utils/condition/processor.py @@ -1,381 +1,362 @@ from collections.abc import Sequence -from typing import Any, Optional +from typing import Any, Literal -from core.file.file_obj import FileVar +from core.file import FileAttribute, file_manager +from core.variables.segments import ArrayFileSegment from core.workflow.entities.variable_pool import VariablePool -from core.workflow.utils.condition.entities import Condition -from core.workflow.utils.variable_template_parser import VariableTemplateParser + +from .entities import Condition, SubCondition, SupportedComparisonOperator class ConditionProcessor: - def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]): - input_conditions = [] - group_result = [] - - index = 0 - for condition in conditions: - index += 1 - actual_value = variable_pool.get_any(condition.variable_selector) - - expected_value = None - if condition.value is not None: - variable_template_parser = VariableTemplateParser(template=condition.value) - 
variable_selectors = variable_template_parser.extract_variable_selectors() - if variable_selectors: - for variable_selector in variable_selectors: - value = variable_pool.get_any(variable_selector.value_selector) - expected_value = variable_template_parser.format({variable_selector.variable: value}) - - if expected_value is None: - expected_value = condition.value - else: - expected_value = condition.value - - comparison_operator = condition.comparison_operator - input_conditions.append( - { - "actual_value": actual_value, - "expected_value": expected_value, - "comparison_operator": comparison_operator, - } - ) - - result = self.evaluate_condition(actual_value, comparison_operator, expected_value) - group_result.append(result) - - return input_conditions, group_result - - def evaluate_condition( + def process_conditions( self, - actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None], - comparison_operator: str, - expected_value: Optional[str] = None, - ) -> bool: - """ - Evaluate condition - :param actual_value: actual value - :param expected_value: expected value - :param comparison_operator: comparison operator + *, + variable_pool: VariablePool, + conditions: Sequence[Condition], + operator: Literal["and", "or"], + ): + input_conditions = [] + group_results = [] - :return: bool - """ - if comparison_operator == "contains": - return self._assert_contains(actual_value, expected_value) - elif comparison_operator == "not contains": - return self._assert_not_contains(actual_value, expected_value) - elif comparison_operator == "start with": - return self._assert_start_with(actual_value, expected_value) - elif comparison_operator == "end with": - return self._assert_end_with(actual_value, expected_value) - elif comparison_operator == "is": - return self._assert_is(actual_value, expected_value) - elif comparison_operator == "is not": - return self._assert_is_not(actual_value, expected_value) - elif comparison_operator == "empty": - 
return self._assert_empty(actual_value) - elif comparison_operator == "not empty": - return self._assert_not_empty(actual_value) - elif comparison_operator == "=": - return self._assert_equal(actual_value, expected_value) - elif comparison_operator == "≠": - return self._assert_not_equal(actual_value, expected_value) - elif comparison_operator == ">": - return self._assert_greater_than(actual_value, expected_value) - elif comparison_operator == "<": - return self._assert_less_than(actual_value, expected_value) - elif comparison_operator == "≥": - return self._assert_greater_than_or_equal(actual_value, expected_value) - elif comparison_operator == "≤": - return self._assert_less_than_or_equal(actual_value, expected_value) - elif comparison_operator == "null": - return self._assert_null(actual_value) - elif comparison_operator == "not null": - return self._assert_not_null(actual_value) - else: - raise ValueError(f"Invalid comparison operator: {comparison_operator}") + for condition in conditions: + variable = variable_pool.get(condition.variable_selector) - def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: - """ - Assert contains - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False + if isinstance(variable, ArrayFileSegment) and condition.comparison_operator in { + "contains", + "not contains", + "all of", + }: + # check sub conditions + if not condition.sub_variable_condition: + raise ValueError("Sub variable is required") + result = _process_sub_conditions( + variable=variable, + sub_conditions=condition.sub_variable_condition.conditions, + operator=condition.sub_variable_condition.logical_operator, + ) + else: + actual_value = variable.value if variable else None + expected_value = condition.value + if isinstance(expected_value, str): + expected_value = variable_pool.convert_template(expected_value).text + input_conditions.append( + { + 
"actual_value": actual_value, + "expected_value": expected_value, + "comparison_operator": condition.comparison_operator, + } + ) + result = _evaluate_condition( + value=actual_value, + operator=condition.comparison_operator, + expected=expected_value, + ) + group_results.append(result) - if not isinstance(actual_value, str | list): - raise ValueError("Invalid actual value type: string or array") + final_result = all(group_results) if operator == "and" else any(group_results) + return input_conditions, group_results, final_result - if expected_value not in actual_value: - return False - return True - def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: - """ - Assert not contains - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return True +def _evaluate_condition( + *, + operator: SupportedComparisonOperator, + value: Any, + expected: str | Sequence[str] | None, +) -> bool: + match operator: + case "contains": + return _assert_contains(value=value, expected=expected) + case "not contains": + return _assert_not_contains(value=value, expected=expected) + case "start with": + return _assert_start_with(value=value, expected=expected) + case "end with": + return _assert_end_with(value=value, expected=expected) + case "is": + return _assert_is(value=value, expected=expected) + case "is not": + return _assert_is_not(value=value, expected=expected) + case "empty": + return _assert_empty(value=value) + case "not empty": + return _assert_not_empty(value=value) + case "=": + return _assert_equal(value=value, expected=expected) + case "≠": + return _assert_not_equal(value=value, expected=expected) + case ">": + return _assert_greater_than(value=value, expected=expected) + case "<": + return _assert_less_than(value=value, expected=expected) + case "≥": + return _assert_greater_than_or_equal(value=value, expected=expected) + case "≤": + return 
_assert_less_than_or_equal(value=value, expected=expected) + case "null": + return _assert_null(value=value) + case "not null": + return _assert_not_null(value=value) + case "in": + return _assert_in(value=value, expected=expected) + case "not in": + return _assert_not_in(value=value, expected=expected) + case "all of" if isinstance(expected, list): + return _assert_all_of(value=value, expected=expected) + case _: + raise ValueError(f"Unsupported operator: {operator}") - if not isinstance(actual_value, str | list): - raise ValueError("Invalid actual value type: string or array") - if expected_value in actual_value: - return False - return True - - def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert start with - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str): - raise ValueError("Invalid actual value type: string") - - if not actual_value.startswith(expected_value): - return False - return True - - def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert end with - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str): - raise ValueError("Invalid actual value type: string") - - if not actual_value.endswith(expected_value): - return False - return True - - def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert is - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, str): - raise ValueError("Invalid actual value type: string") - - if actual_value != expected_value: - return False - return True - - def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: - """ 
- Assert is not - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, str): - raise ValueError("Invalid actual value type: string") - - if actual_value == expected_value: - return False - return True - - def _assert_empty(self, actual_value: Optional[str]) -> bool: - """ - Assert empty - :param actual_value: actual value - :return: - """ - if not actual_value: - return True +def _assert_contains(*, value: Any, expected: Any) -> bool: + if not value: return False - def _assert_not_empty(self, actual_value: Optional[str]) -> bool: - """ - Assert not empty - :param actual_value: actual value - :return: - """ - if actual_value: - return True + if not isinstance(value, str | list): + raise ValueError("Invalid actual value type: string or array") + + if expected not in value: + return False + return True + + +def _assert_not_contains(*, value: Any, expected: Any) -> bool: + if not value: + return True + + if not isinstance(value, str | list): + raise ValueError("Invalid actual value type: string or array") + + if expected in value: + return False + return True + + +def _assert_start_with(*, value: Any, expected: Any) -> bool: + if not value: return False - def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: - """ - Assert equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False + if not isinstance(value, str): + raise ValueError("Invalid actual value type: string") - if not isinstance(actual_value, int | float): - raise ValueError("Invalid actual value type: number") + if not value.startswith(expected): + return False + return True - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - if actual_value != expected_value: - return False - return True - - 
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: - """ - Assert not equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError("Invalid actual value type: number") - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value == expected_value: - return False - return True - - def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: - """ - Assert greater than - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError("Invalid actual value type: number") - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value <= expected_value: - return False - return True - - def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: - """ - Assert less than - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError("Invalid actual value type: number") - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value >= expected_value: - return False - return True - - def _assert_greater_than_or_equal( - self, actual_value: Optional[int | float], expected_value: str | int | float - ) -> bool: - """ - Assert greater than or equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value 
is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError("Invalid actual value type: number") - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value < expected_value: - return False - return True - - def _assert_less_than_or_equal( - self, actual_value: Optional[int | float], expected_value: str | int | float - ) -> bool: - """ - Assert less than or equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError("Invalid actual value type: number") - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value > expected_value: - return False - return True - - def _assert_null(self, actual_value: Optional[int | float]) -> bool: - """ - Assert null - :param actual_value: actual value - :return: - """ - if actual_value is None: - return True +def _assert_end_with(*, value: Any, expected: Any) -> bool: + if not value: return False - def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: - """ - Assert not null - :param actual_value: actual value - :return: - """ - if actual_value is not None: - return True + if not isinstance(value, str): + raise ValueError("Invalid actual value type: string") + + if not value.endswith(expected): + return False + return True + + +def _assert_is(*, value: Any, expected: Any) -> bool: + if value is None: return False + if not isinstance(value, str): + raise ValueError("Invalid actual value type: string") -class ConditionAssertionError(Exception): - def __init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None: - self.message = message - self.conditions = conditions - self.sub_condition_compare_results = 
sub_condition_compare_results - super().__init__(self.message) + if value != expected: + return False + return True + + +def _assert_is_not(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, str): + raise ValueError("Invalid actual value type: string") + + if value == expected: + return False + return True + + +def _assert_empty(*, value: Any) -> bool: + if not value: + return True + return False + + +def _assert_not_empty(*, value: Any) -> bool: + if value: + return True + return False + + +def _assert_equal(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value != expected: + return False + return True + + +def _assert_not_equal(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value == expected: + return False + return True + + +def _assert_greater_than(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value <= expected: + return False + return True + + +def _assert_less_than(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value >= expected: + return False + return True + + +def _assert_greater_than_or_equal(*, value: Any, 
expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value < expected: + return False + return True + + +def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value > expected: + return False + return True + + +def _assert_null(*, value: Any) -> bool: + if value is None: + return True + return False + + +def _assert_not_null(*, value: Any) -> bool: + if value is not None: + return True + return False + + +def _assert_in(*, value: Any, expected: Any) -> bool: + if not value: + return False + + if not isinstance(expected, list): + raise ValueError("Invalid expected value type: array") + + if value not in expected: + return False + return True + + +def _assert_not_in(*, value: Any, expected: Any) -> bool: + if not value: + return True + + if not isinstance(expected, list): + raise ValueError("Invalid expected value type: array") + + if value in expected: + return False + return True + + +def _assert_all_of(*, value: Any, expected: Sequence[str]) -> bool: + if not value: + return False + + if not all(item in value for item in expected): + return False + return True + + +def _process_sub_conditions( + variable: ArrayFileSegment, + sub_conditions: Sequence[SubCondition], + operator: Literal["and", "or"], +) -> bool: + files = variable.value + group_results = [] + for condition in sub_conditions: + key = FileAttribute(condition.key) + values = [file_manager.get_attr(file=file, attr=key) for file in files] + sub_group_results = [ + _evaluate_condition( + value=value, + operator=condition.comparison_operator, + 
expected=condition.value, + ) + for value in values + ] + # Determine the result based on the presence of "not" in the comparison operator + result = all(sub_group_results) if "not" in condition.comparison_operator else any(sub_group_results) + group_results.append(result) + return all(group_results) if operator == "and" else any(group_results) diff --git a/api/core/workflow/utils/variable_template_parser.py b/api/core/workflow/utils/variable_template_parser.py index fd0e48b862..1d8fb38ebf 100644 --- a/api/core/workflow/utils/variable_template_parser.py +++ b/api/core/workflow/utils/variable_template_parser.py @@ -1,42 +1,21 @@ import re -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Any from core.workflow.entities.variable_entities import VariableSelector -from core.workflow.entities.variable_pool import VariablePool REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") +SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") -def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: - """ - This is an alternative to the VariableTemplateParser class, - offering the same functionality but with better readability and ease of use. - """ - variable_keys = [match[0] for match in re.findall(REGEX, template)] - variable_keys = list(set(variable_keys)) - # This key_selector is a tuple of (key, selector) where selector is a list of keys - # e.g. 
('#node_id.query.name#', ['node_id', 'query', 'name']) - key_selectors = filter( - lambda t: len(t[1]) >= 2, - ((key, selector.replace("#", "").split(".")) for key, selector in zip(variable_keys, variable_keys)), - ) - inputs = {key: variable_pool.get_any(selector) for key, selector in key_selectors} - - def replacer(match): - key = match.group(1) - # return original matched string if key not found - value = inputs.get(key, match.group(0)) - if value is None: - value = "" - value = str(value) - # remove template variables if required - return re.sub(REGEX, r"{\1}", value) - - result = re.sub(REGEX, replacer, template) - result = re.sub(r"<\|.*?\|>", "", result) - return result +def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]: + parts = SELECTOR_PATTERN.split(template) + selectors = [] + for part in filter(lambda x: x, parts): + if "." in part and part[0] == "#" and part[-1] == "#": + selectors.append(VariableSelector(variable=f"{part}", value_selector=part[1:-1].split("."))) + return selectors class VariableTemplateParser: diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index d3576197d1..eb812bad21 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -8,10 +8,8 @@ from configs import dify_config from core.app.app_config.entities import FileExtraConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.file_obj import FileTransferMethod, FileType, FileVar -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType, UserFrom +from core.file.models import File, FileTransferMethod, FileType, ImageConfig +from core.workflow.callbacks import WorkflowCallback from core.workflow.entities.variable_pool import 
VariablePool from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent @@ -19,10 +17,12 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.event import RunEvent -from core.workflow.nodes.llm.entities import LLMNodeData -from core.workflow.nodes.node_mapping import node_classes +from core.workflow.nodes import NodeType +from core.workflow.nodes.base import BaseNode, BaseNodeData +from core.workflow.nodes.event import NodeEvent +from core.workflow.nodes.llm import LLMNodeData +from core.workflow.nodes.node_mapping import node_type_classes_mapping +from models.enums import UserFrom from models.workflow import ( Workflow, WorkflowType, @@ -115,7 +115,7 @@ class WorkflowEntry: @classmethod def single_step_run( cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict - ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]: + ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: """ Single step run workflow node :param workflow: Workflow instance @@ -144,8 +144,8 @@ class WorkflowEntry: raise ValueError("node id not found in workflow graph") # Get node class - node_type = NodeType.value_of(node_config.get("data", {}).get("type")) - node_cls = node_classes.get(node_type) + node_type = NodeType(node_config.get("data", {}).get("type")) + node_cls = node_type_classes_mapping.get(node_type) node_cls = cast(type[BaseNode], node_cls) if not node_cls: @@ -162,7 +162,7 @@ class WorkflowEntry: graph = Graph.init(graph_config=workflow.graph_dict) # init workflow run state - node_instance: BaseNode = node_cls( + 
node_instance = node_cls( id=str(uuid.uuid4()), config=node_config, graph_init_params=GraphInitParams( @@ -205,32 +205,27 @@ class WorkflowEntry: except Exception as e: raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) - @classmethod - def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]: - """ - Handle special values - :param value: value - :return: - """ - if not value: - return None + @staticmethod + def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: + return WorkflowEntry._handle_special_values(value) - new_value = dict(value) if value else {} - if isinstance(new_value, dict): - for key, val in new_value.items(): - if isinstance(val, FileVar): - new_value[key] = val.to_dict() - elif isinstance(val, list): - new_val = [] - for v in val: - if isinstance(v, FileVar): - new_val.append(v.to_dict()) - else: - new_val.append(v) - - new_value[key] = new_val - - return new_value + @staticmethod + def _handle_special_values(value: Any) -> Any: + if value is None: + return value + if isinstance(value, dict): + res = {} + for k, v in value.items(): + res[k] = WorkflowEntry._handle_special_values(v) + return res + if isinstance(value, list): + res = [] + for item in value: + res.append(WorkflowEntry._handle_special_values(item)) + return res + if isinstance(value, File): + return value.to_dict() + return value @classmethod def mapping_user_inputs_to_variable_pool( @@ -276,15 +271,19 @@ class WorkflowEntry: for item in input_value: if isinstance(item, dict) and "type" in item and item["type"] == "image": transfer_method = FileTransferMethod.value_of(item.get("transfer_method")) - file = FileVar( + file = File( tenant_id=tenant_id, type=FileType.IMAGE, transfer_method=transfer_method, - url=item.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, + remote_url=item.get("url") + if transfer_method == FileTransferMethod.REMOTE_URL + else None, 
related_id=item.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=FileExtraConfig(image_config={"detail": detail} if detail else None), + _extra_config=FileExtraConfig( + image_config=ImageConfig(detail=detail) if detail else None + ), ) new_value.append(file) diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index f96bb5ef74..9c5955c8c5 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,6 +1,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolEntity from events.app_event import app_draft_workflow_was_synced diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index c5e98e263f..453395e8d7 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,6 +1,6 @@ from typing import cast -from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes import NodeType from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index f90629262d..5fc4f88832 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py 
@@ -72,7 +72,7 @@ class Storage: logging.exception("Failed to save file: %s", e) raise e - def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]: + def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]: try: if stream: return self.load_stream(filename) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py new file mode 100644 index 0000000000..91b188270a --- /dev/null +++ b/api/factories/file_factory.py @@ -0,0 +1,254 @@ +import mimetypes +from collections.abc import Mapping, Sequence +from typing import Any + +from sqlalchemy import select + +from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS +from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType +from core.helper import ssrf_proxy +from extensions.ext_database import db +from models import MessageFile, ToolFile, UploadFile +from models.enums import CreatedByRole + + +def build_from_message_files( + *, + message_files: Sequence["MessageFile"], + tenant_id: str, + config: FileExtraConfig, +) -> Sequence[File]: + results = [ + build_from_message_file(message_file=file, tenant_id=tenant_id, config=config) + for file in message_files + if file.belongs_to != FileBelongsTo.ASSISTANT + ] + return results + + +def build_from_message_file( + *, + message_file: "MessageFile", + tenant_id: str, + config: FileExtraConfig, +): + mapping = { + "transfer_method": message_file.transfer_method, + "url": message_file.url, + "id": message_file.id, + "type": message_file.type, + "upload_file_id": message_file.upload_file_id, + } + return build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + user_id=message_file.created_by, + role=CreatedByRole(message_file.created_by_role), + config=config, + ) + + +def build_from_mapping( + *, + mapping: Mapping[str, Any], + tenant_id: str, + user_id: str, + role: "CreatedByRole", + config: FileExtraConfig, +): + transfer_method = 
FileTransferMethod.value_of(mapping.get("transfer_method")) + match transfer_method: + case FileTransferMethod.REMOTE_URL: + file = _build_from_remote_url( + mapping=mapping, + tenant_id=tenant_id, + config=config, + transfer_method=transfer_method, + ) + case FileTransferMethod.LOCAL_FILE: + file = _build_from_local_file( + mapping=mapping, + tenant_id=tenant_id, + user_id=user_id, + role=role, + config=config, + transfer_method=transfer_method, + ) + case FileTransferMethod.TOOL_FILE: + file = _build_from_tool_file( + mapping=mapping, + tenant_id=tenant_id, + user_id=user_id, + config=config, + transfer_method=transfer_method, + ) + case _: + raise ValueError(f"Invalid file transfer method: {transfer_method}") + + return file + + +def build_from_mappings( + *, + mappings: Sequence[Mapping[str, Any]], + config: FileExtraConfig | None, + tenant_id: str, + user_id: str, + role: "CreatedByRole", +) -> Sequence[File]: + if not config: + return [] + + files = [ + build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + user_id=user_id, + role=role, + config=config, + ) + for mapping in mappings + ] + + if ( + # If image config is set. + config.image_config + # And the number of image files exceeds the maximum limit + and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits + ): + raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") + if config.number_limits and len(files) > config.number_limits: + raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") + + return files + + +def _build_from_local_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + user_id: str, + role: "CreatedByRole", + config: FileExtraConfig, + transfer_method: FileTransferMethod, +): + # check if the upload file exists. 
+ file_type = FileType.value_of(mapping.get("type")) + stmt = select(UploadFile).where( + UploadFile.id == mapping.get("upload_file_id"), + UploadFile.tenant_id == tenant_id, + UploadFile.created_by == user_id, + UploadFile.created_by_role == role, + ) + if file_type == FileType.IMAGE: + stmt = stmt.where(UploadFile.extension.in_(IMAGE_EXTENSIONS)) + elif file_type == FileType.VIDEO: + stmt = stmt.where(UploadFile.extension.in_(VIDEO_EXTENSIONS)) + elif file_type == FileType.AUDIO: + stmt = stmt.where(UploadFile.extension.in_(AUDIO_EXTENSIONS)) + elif file_type == FileType.DOCUMENT: + stmt = stmt.where(UploadFile.extension.in_(DOCUMENT_EXTENSIONS)) + row = db.session.scalar(stmt) + if row is None: + raise ValueError("Invalid upload file") + file = File( + id=mapping.get("id"), + filename=row.name, + extension=row.extension, + mime_type=row.mime_type, + tenant_id=tenant_id, + type=file_type, + transfer_method=transfer_method, + remote_url=None, + related_id=mapping.get("upload_file_id"), + _extra_config=config, + size=row.size, + ) + return file + + +def _build_from_remote_url( + *, + mapping: Mapping[str, Any], + tenant_id: str, + config: FileExtraConfig, + transfer_method: FileTransferMethod, +): + url = mapping.get("url") + if not url: + raise ValueError("Invalid file url") + resp = ssrf_proxy.head(url) + resp.raise_for_status() + + # Try to extract filename from response headers or URL + content_disposition = resp.headers.get("Content-Disposition") + if content_disposition: + filename = content_disposition.split("filename=")[-1].strip('"') + else: + filename = url.split("/")[-1].split("?")[0] + # If filename is empty, set a default one + if not filename: + filename = "unknown_file" + + # Determine file extension + extension = "." + filename.split(".")[-1] if "." 
in filename else ".bin" + + # Create the File object + file_size = int(resp.headers.get("Content-Length", -1)) + mime_type = str(resp.headers.get("Content-Type", "")) + if not mime_type: + mime_type, _ = mimetypes.guess_type(url) + file = File( + id=mapping.get("id"), + filename=filename, + tenant_id=tenant_id, + type=FileType.value_of(mapping.get("type")), + transfer_method=transfer_method, + remote_url=url, + _extra_config=config, + mime_type=mime_type, + extension=extension, + size=file_size, + ) + return file + + +def _build_from_tool_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + user_id: str, + config: FileExtraConfig, + transfer_method: FileTransferMethod, +): + tool_file = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == mapping.get("tool_file_id"), + ToolFile.tenant_id == tenant_id, + ToolFile.user_id == user_id, + ) + .first() + ) + if tool_file is None: + raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") + + path = tool_file.file_key + if "." in path: + extension = "." 
+ path.split("/")[-1].split(".")[-1] + else: + extension = ".bin" + file = File( + id=mapping.get("id"), + tenant_id=tenant_id, + filename=tool_file.name, + type=FileType.value_of(mapping.get("type")), + transfer_method=transfer_method, + remote_url=tool_file.original_url, + related_id=tool_file.id, + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + _extra_config=config, + ) + return file diff --git a/api/core/app/segments/factory.py b/api/factories/variable_factory.py similarity index 73% rename from api/core/app/segments/factory.py rename to api/factories/variable_factory.py index 40a69ed4eb..a758f9981f 100644 --- a/api/core/app/segments/factory.py +++ b/api/factories/variable_factory.py @@ -2,29 +2,32 @@ from collections.abc import Mapping from typing import Any from configs import dify_config - -from .exc import VariableError -from .segments import ( +from core.file import File +from core.variables import ( ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayNumberVariable, + ArrayObjectSegment, + ArrayObjectVariable, + ArrayStringSegment, + ArrayStringVariable, + FileSegment, FloatSegment, + FloatVariable, IntegerSegment, + IntegerVariable, NoneSegment, ObjectSegment, - Segment, - StringSegment, -) -from .types import SegmentType -from .variables import ( - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - FloatVariable, - IntegerVariable, ObjectVariable, SecretVariable, + Segment, + SegmentType, + StringSegment, StringVariable, Variable, ) +from core.variables.exc import VariableError def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: @@ -71,6 +74,22 @@ def build_segment(value: Any, /) -> Segment: return FloatSegment(value=value) if isinstance(value, dict): return ObjectSegment(value=value) + if isinstance(value, File): + return FileSegment(value=value) if isinstance(value, list): - return ArrayAnySegment(value=value) + items = [build_segment(item) for item in value] + 
types = {item.value_type for item in items} + if len(types) != 1: + return ArrayAnySegment(value=value) + match types.pop(): + case SegmentType.STRING: + return ArrayStringSegment(value=value) + case SegmentType.NUMBER: + return ArrayNumberSegment(value=value) + case SegmentType.OBJECT: + return ArrayObjectSegment(value=value) + case SegmentType.FILE: + return ArrayFileSegment(value=value) + case _: + raise ValueError(f"not supported value {value}") raise ValueError(f"not supported value {value}") diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 3dcd88d1de..bf1c491a05 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -3,6 +3,8 @@ from flask_restful import fields from fields.member_fields import simple_account_fields from libs.helper import TimestampField +from .raws import FilesContainedField + class MessageTextField(fields.Raw): def format(self, value): @@ -33,8 +35,12 @@ annotation_hit_history_fields = { message_file_fields = { "id": fields.String, + "filename": fields.String, "type": fields.String, "url": fields.String, + "mime_type": fields.String, + "size": fields.Integer, + "transfer_method": fields.String, "belongs_to": fields.String(default="user"), } @@ -55,7 +61,7 @@ agent_thought_fields = { message_detail_fields = { "id": fields.String, "conversation_id": fields.String, - "inputs": fields.Raw, + "inputs": FilesContainedField, "query": fields.String, "message": fields.Raw, "message_tokens": fields.Integer, @@ -71,7 +77,7 @@ message_detail_fields = { "annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True), "created_at": TimestampField, "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), - "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "message_files": fields.List(fields.Nested(message_file_fields)), "metadata": fields.Raw(attribute="message_metadata_dict"), "status": fields.String, "error": 
fields.String, @@ -99,7 +105,7 @@ simple_model_config_fields = { } simple_message_detail_fields = { - "inputs": fields.Raw, + "inputs": FilesContainedField, "query": fields.String, "message": MessageTextField, "answer": fields.String, @@ -187,7 +193,7 @@ conversation_detail_fields = { simple_conversation_fields = { "id": fields.String, "name": fields.String, - "inputs": fields.Raw, + "inputs": FilesContainedField, "status": fields.String, "introduction": fields.String, "created_at": TimestampField, diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index e5a03ce77e..4ce7644e9d 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -17,3 +17,8 @@ file_fields = { "created_by": fields.String, "created_at": TimestampField, } + +remote_file_info_fields = { + "file_type": fields.String(attribute="file_type"), + "file_length": fields.Integer(attribute="file_length"), +} diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index c938097131..5f6e7884a6 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -3,6 +3,8 @@ from flask_restful import fields from fields.conversation_fields import message_file_fields from libs.helper import TimestampField +from .raws import FilesContainedField + feedback_fields = {"rating": fields.String} retriever_resource_fields = { @@ -63,14 +65,14 @@ message_fields = { "id": fields.String, "conversation_id": fields.String, "parent_message_id": fields.String, - "inputs": fields.Raw, + "inputs": FilesContainedField, "query": fields.String, "answer": fields.String(attribute="re_sign_file_url_answer"), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "created_at": TimestampField, "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), - "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "message_files": 
fields.List(fields.Nested(message_file_fields)), "status": fields.String, "error": fields.String, } diff --git a/api/fields/raws.py b/api/fields/raws.py new file mode 100644 index 0000000000..15ec16ab13 --- /dev/null +++ b/api/fields/raws.py @@ -0,0 +1,17 @@ +from flask_restful import fields + +from core.file import File + + +class FilesContainedField(fields.Raw): + def format(self, value): + return self._format_file_object(value) + + def _format_file_object(self, v): + if isinstance(v, File): + return v.model_dump() + if isinstance(v, dict): + return {k: self._format_file_object(vv) for k, vv in v.items()} + if isinstance(v, list): + return [self._format_file_object(vv) for vv in v] + return v diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 2adef63ada..0d860d6f40 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,7 +1,7 @@ from flask_restful import fields -from core.app.segments import SecretVariable, SegmentType, Variable from core.helper import encrypter +from core.variables import SecretVariable, SegmentType, Variable from fields.member_fields import simple_account_fields from libs.helper import TimestampField diff --git a/api/libs/helper.py b/api/libs/helper.py index fc3be8473c..81ac79bb04 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -16,7 +16,7 @@ from flask import Response, current_app, stream_with_context from flask_restful import fields from core.app.features.rate_limiting.rate_limit import RateLimitGenerator -from core.file.upload_file_parser import UploadFileParser +from core.file import helpers as file_helpers from extensions.ext_redis import redis_client from models.account import Account @@ -33,7 +33,7 @@ class AppIconUrlField(fields.Raw): from models.model import IconType if obj.icon_type == IconType.IMAGE.value: - return UploadFileParser.get_signed_temp_image_url(obj.icon) + return file_helpers.get_signed_file_url(obj.icon) return None diff --git 
a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py index be2c615525..6a7402b16a 100644 --- a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py +++ b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py @@ -8,7 +8,7 @@ Create Date: 2024-06-12 07:49:07.666510 import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '04c602f5dc9b' @@ -20,8 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.create_table('tracing_app_configs', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('app_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), sa.Column('tracing_provider', sa.String(length=255), nullable=True), sa.Column('tracing_config', sa.JSON(), nullable=True), sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), diff --git a/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py b/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py new file mode 100644 index 0000000000..c17d1db77a --- /dev/null +++ b/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py @@ -0,0 +1,49 @@ +"""add name and size to tool_files + +Revision ID: bbadea11becb +Revises: 33f5fac87f29 +Create Date: 2024-10-10 05:16:14.764268 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = 'bbadea11becb' +down_revision = 'd8e744d88ed6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + # Get the database connection + conn = op.get_bind() + + # Use SQLAlchemy inspector to get the columns of the 'tool_files' table + inspector = sa.inspect(conn) + columns = [col['name'] for col in inspector.get_columns('tool_files')] + + # If 'name' or 'size' columns already exist, exit the upgrade function + if 'name' in columns or 'size' in columns: + return + + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.add_column(sa.Column('name', sa.String(), nullable=True)) + batch_op.add_column(sa.Column('size', sa.Integer(), nullable=True)) + op.execute("UPDATE tool_files SET name = '' WHERE name IS NULL") + op.execute("UPDATE tool_files SET size = -1 WHERE size IS NULL") + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('name', existing_type=sa.String(), nullable=False) + batch_op.alter_column('size', existing_type=sa.Integer(), nullable=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.drop_column('size') + batch_op.drop_column('name') + # ### end Alembic commands ### diff --git a/api/migrations/versions/3b18fea55204_add_tool_label_bings.py b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py index db3119badf..bf54c247ea 100644 --- a/api/migrations/versions/3b18fea55204_add_tool_label_bings.py +++ b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py @@ -8,7 +8,7 @@ Create Date: 2024-05-14 09:27:18.857890 import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. 
revision = '3b18fea55204' @@ -20,7 +20,7 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.create_table('tool_label_bindings', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), sa.Column('tool_id', sa.String(length=64), nullable=False), sa.Column('tool_type', sa.String(length=40), nullable=False), sa.Column('label_name', sa.String(length=40), nullable=False), diff --git a/api/migrations/versions/4e99a8df00ff_add_load_balancing.py b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py index 67d7b9fbf5..3be4ba4f2a 100644 --- a/api/migrations/versions/4e99a8df00ff_add_load_balancing.py +++ b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py @@ -8,7 +8,7 @@ Create Date: 2024-05-10 12:08:09.812736 import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '4e99a8df00ff' @@ -20,8 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.create_table('load_balancing_model_configs', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('provider_name', sa.String(length=255), nullable=False), sa.Column('model_name', sa.String(length=255), nullable=False), sa.Column('model_type', sa.String(length=40), nullable=False), @@ -36,8 +36,8 @@ def upgrade(): batch_op.create_index('load_balancing_model_config_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False) op.create_table('provider_model_settings', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('provider_name', sa.String(length=255), nullable=False), sa.Column('model_name', sa.String(length=255), nullable=False), sa.Column('model_type', sa.String(length=40), nullable=False), diff --git a/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py b/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py index f63bad9345..2ba0e13caa 100644 --- a/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py +++ b/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py @@ -8,7 +8,7 @@ Create Date: 2024-05-14 07:31:29.702766 import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '7b45942e39bb' @@ -20,8 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.create_table('data_source_api_key_auth_bindings', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('category', sa.String(length=255), nullable=False), sa.Column('provider', sa.String(length=255), nullable=False), sa.Column('credentials', sa.Text(), nullable=True), diff --git a/api/migrations/versions/7bdef072e63a_add_workflow_tool.py b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py index 67b61e5c76..f09a682f28 100644 --- a/api/migrations/versions/7bdef072e63a_add_workflow_tool.py +++ b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py @@ -1,6 +1,6 @@ """add workflow tool -Revision ID: 7bdef072e63a +Revision ID: 7bdef072e63a Revises: 5fda94355fce Create Date: 2024-05-04 09:47:19.366961 @@ -8,7 +8,7 @@ Create Date: 2024-05-04 09:47:19.366961 import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '7bdef072e63a' @@ -20,12 +20,12 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.create_table('tool_workflow_providers', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), sa.Column('name', sa.String(length=40), nullable=False), sa.Column('icon', sa.String(length=255), nullable=False), - sa.Column('app_id', models.StringUUID(), nullable=False), - sa.Column('user_id', models.StringUUID(), nullable=False), - sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('user_id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('description', sa.Text(), nullable=False), sa.Column('parameter_configuration', sa.Text(), server_default='[]', nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), diff --git a/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py b/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py index ff53eb65a6..865572f3a7 100644 --- a/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py +++ b/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py @@ -8,7 +8,7 @@ Create Date: 2024-06-25 03:20:46.012193 import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '7e6a8693e07a' @@ -20,9 +20,9 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.create_table('dataset_permissions', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('dataset_id', models.StringUUID(), nullable=False), - sa.Column('account_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('account_id', models.types.StringUUID(), nullable=False), sa.Column('has_permission', sa.Boolean(), server_default=sa.text('true'), nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.PrimaryKeyConstraint('id', name='dataset_permission_pkey') diff --git a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py index 1ac44d083a..469c04338a 100644 --- a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py +++ b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql -import models as models +import models.types # revision identifiers, used by Alembic. revision = 'c031d46af369' @@ -21,8 +21,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.create_table('trace_app_config', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('app_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), sa.Column('tracing_provider', sa.String(length=255), nullable=True), sa.Column('tracing_config', sa.JSON(), nullable=True), sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), diff --git a/api/models/__init__.py b/api/models/__init__.py index 30ceef057e..1d8bae6cfa 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1,29 +1,55 @@ -from enum import Enum +from .account import Account, AccountIntegrate, InvitationCode, Tenant +from .dataset import Dataset, DatasetProcessRule, Document, DocumentSegment +from .model import ( + ApiToken, + App, + AppMode, + Conversation, + EndUser, + FileUploadConfig, + InstalledApp, + Message, + MessageAnnotation, + MessageFile, + RecommendedApp, + Site, + UploadFile, +) +from .source import DataSourceOauthBinding +from .tools import ToolFile +from .workflow import ( + ConversationVariable, + Workflow, + WorkflowAppLog, + WorkflowRun, +) -from .model import App, AppMode, Message -from .types import StringUUID -from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus - -__all__ = ["ConversationVariable", "StringUUID", "AppMode", "WorkflowNodeExecutionStatus", "Workflow", "App", "Message"] - - -class CreatedByRole(Enum): - """ - Enum class for createdByRole - """ - - ACCOUNT = "account" - END_USER = "end_user" - - @classmethod - def value_of(cls, value: str) -> "CreatedByRole": - """ - Get value of given mode. 
- - :param value: mode value - :return: mode - """ - for role in cls: - if role.value == value: - return role - raise ValueError(f"invalid createdByRole value {value}") +__all__ = [ + "ConversationVariable", + "Document", + "Dataset", + "DatasetProcessRule", + "DocumentSegment", + "DataSourceOauthBinding", + "AppMode", + "Workflow", + "App", + "Message", + "EndUser", + "MessageFile", + "UploadFile", + "Account", + "WorkflowAppLog", + "WorkflowRun", + "Site", + "InstalledApp", + "RecommendedApp", + "ApiToken", + "AccountIntegrate", + "InvitationCode", + "Tenant", + "Conversation", + "MessageAnnotation", + "FileUploadConfig", + "ToolFile", +] diff --git a/api/models/dataset.py b/api/models/dataset.py index 4224ee5e9c..4e2ccab7e8 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -560,7 +560,7 @@ class DocumentSegment(db.Model): ) def get_sign_content(self): - pattern = r"/files/([a-f0-9\-]+)/image-preview" + pattern = r"/files/([a-f0-9\-]+)/file-preview" text = self.content matches = re.finditer(pattern, text) signed_urls = [] @@ -568,7 +568,7 @@ class DocumentSegment(db.Model): upload_file_id = match.group(1) nonce = os.urandom(16).hex() timestamp = str(int(time.time())) - data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() diff --git a/api/models/enums.py b/api/models/enums.py new file mode 100644 index 0000000000..a83d35e042 --- /dev/null +++ b/api/models/enums.py @@ -0,0 +1,16 @@ +from enum import Enum + + +class CreatedByRole(str, Enum): + ACCOUNT = "account" + END_USER = "end_user" + + +class UserFrom(str, Enum): + ACCOUNT = "account" + END_USER = "end-user" + + +class WorkflowRunTriggeredFrom(str, Enum): + DEBUGGING = "debugging" + APP_RUN = "app-run" 
diff --git a/api/models/model.py b/api/models/model.py index 0ac9334321..cb2855bf72 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,24 +1,37 @@ import json import re import uuid +from collections.abc import Mapping, Sequence +from datetime import datetime from enum import Enum -from typing import Optional +from typing import Any, Literal, Optional from flask import request from flask_login import UserMixin +from pydantic import BaseModel, Field from sqlalchemy import Float, func, text from sqlalchemy.orm import Mapped, mapped_column from configs import dify_config +from core.file import FILE_MODEL_IDENTITY, File, FileExtraConfig, FileTransferMethod, FileType +from core.file import helpers as file_helpers from core.file.tool_file_parser import ToolFileParser -from core.file.upload_file_parser import UploadFileParser from extensions.ext_database import db from libs.helper import generate_string +from models.enums import CreatedByRole from .account import Account, Tenant from .types import StringUUID +class FileUploadConfig(BaseModel): + enabled: bool = Field(default=False) + allowed_file_types: Sequence[FileType] = Field(default_factory=list) + allowed_extensions: Sequence[str] = Field(default_factory=list) + allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + number_limits: int = Field(default=0, gt=0, le=10) + + class DifySetup(db.Model): __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) @@ -27,7 +40,7 @@ class DifySetup(db.Model): setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class AppMode(Enum): +class AppMode(str, Enum): COMPLETION = "completion" WORKFLOW = "workflow" CHAT = "chat" @@ -59,7 +72,7 @@ class App(db.Model): __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - 
tenant_id = db.Column(StringUUID, nullable=False) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) mode = db.Column(db.String(255), nullable=False) @@ -530,7 +543,7 @@ class Conversation(db.Model): mode = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False) summary = db.Column(db.Text) - inputs = db.Column(db.JSON) + _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) introduction = db.Column(db.Text) system_instruction = db.Column(db.Text) system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) @@ -552,6 +565,28 @@ class Conversation(db.Model): is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + @property + def inputs(self): + inputs = self._inputs.copy() + for key, value in inputs.items(): + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + inputs[key] = File.model_validate(value) + elif isinstance(value, list) and all( + isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value + ): + inputs[key] = [File.model_validate(item) for item in value] + return inputs + + @inputs.setter + def inputs(self, value: Mapping[str, Any]): + inputs = dict(value) + for k, v in inputs.items(): + if isinstance(v, File): + inputs[k] = v.model_dump() + elif isinstance(v, list) and all(isinstance(item, File) for item in v): + inputs[k] = [item.model_dump() for item in v] + self._inputs = inputs + @property def model_config(self): model_config = {} @@ -700,13 +735,13 @@ class Message(db.Model): model_id = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) - inputs = db.Column(db.JSON) - query = 
db.Column(db.Text, nullable=False) + _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) + query: Mapped[str] = db.Column(db.Text, nullable=False) message = db.Column(db.JSON, nullable=False) message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - answer = db.Column(db.Text, nullable=False) + answer: Mapped[str] = db.Column(db.Text, nullable=False) answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) @@ -717,15 +752,37 @@ class Message(db.Model): status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) error = db.Column(db.Text) message_metadata = db.Column(db.Text) - invoke_from = db.Column(db.String(255), nullable=True) + invoke_from: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True) from_source = db.Column(db.String(255), nullable=False) - from_end_user_id = db.Column(StringUUID) - from_account_id = db.Column(StringUUID) + from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) + from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) + @property + def inputs(self): + inputs = self._inputs.copy() + for key, value in inputs.items(): + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + inputs[key] = File.model_validate(value) + elif 
isinstance(value, list) and all( + isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value + ): + inputs[key] = [File.model_validate(item) for item in value] + return inputs + + @inputs.setter + def inputs(self, value: Mapping[str, Any]): + inputs = dict(value) + for k, v in inputs.items(): + if isinstance(v, File): + inputs[k] = v.model_dump() + elif isinstance(v, list) and all(isinstance(item, File) for item in v): + inputs[k] = [item.model_dump() for item in v] + self._inputs = inputs + @property def re_sign_file_url_answer(self) -> str: if not self.answer: @@ -772,19 +829,29 @@ class Message(db.Model): sign_url = ToolFileParser.get_tool_file_manager().sign_file( tool_file_id=tool_file_id, extension=extension ) - else: + elif "file-preview" in url: # get upload file id - upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp=" + upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp=" result = re.search(upload_file_id_pattern, url) if not result: continue upload_file_id = result.group(1) - if not upload_file_id: continue - - sign_url = UploadFileParser.get_signed_temp_image_url(upload_file_id) + sign_url = file_helpers.get_signed_file_url(upload_file_id) + elif "image-preview" in url: + # image-preview is deprecated, use file-preview instead + upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp=" + result = re.search(upload_file_id_pattern, url) + if not result: + continue + upload_file_id = result.group(1) + if not upload_file_id: + continue + sign_url = file_helpers.get_signed_file_url(upload_file_id) + else: + continue re_sign_file_url_answer = re_sign_file_url_answer.replace(url, sign_url) @@ -870,50 +937,71 @@ class Message(db.Model): @property def message_files(self): - return db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all() + from factories import file_factory - @property - def files(self): - message_files = self.message_files + 
message_files = db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all() + current_app = db.session.query(App).filter(App.id == self.app_id).first() + if not current_app: + raise ValueError(f"App {self.app_id} not found") - files = [] + files: list[File] = [] for message_file in message_files: - url = message_file.url - if message_file.type == "image": - if message_file.transfer_method == "local_file": - upload_file = ( - db.session.query(UploadFile).filter(UploadFile.id == message_file.upload_file_id).first() - ) - - url = UploadFileParser.get_image_data(upload_file=upload_file, force_url=True) - if message_file.transfer_method == "tool_file": - # get tool file id - tool_file_id = message_file.url.split("/")[-1] - # trim extension - tool_file_id = tool_file_id.split(".")[0] - - # get extension - if "." in message_file.url: - extension = f'.{message_file.url.split(".")[-1]}' - if len(extension) > 10: - extension = ".bin" - else: - extension = ".bin" - # add sign url - url = ToolFileParser.get_tool_file_manager().sign_file( - tool_file_id=tool_file_id, extension=extension - ) - - files.append( - { + if message_file.transfer_method == "local_file": + if message_file.upload_file_id is None: + raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id") + file = file_factory.build_from_mapping( + mapping={ + "id": message_file.id, + "upload_file_id": message_file.upload_file_id, + "transfer_method": message_file.transfer_method, + "type": message_file.type, + }, + tenant_id=current_app.tenant_id, + user_id=self.from_account_id or self.from_end_user_id or "", + role=CreatedByRole(message_file.created_by_role), + config=FileExtraConfig(), + ) + elif message_file.transfer_method == "remote_url": + if message_file.url is None: + raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url") + file = file_factory.build_from_mapping( + mapping={ + "id": message_file.id, + "type": message_file.type, + 
"transfer_method": message_file.transfer_method, + "url": message_file.url, + }, + tenant_id=current_app.tenant_id, + user_id=self.from_account_id or self.from_end_user_id or "", + role=CreatedByRole(message_file.created_by_role), + config=FileExtraConfig(), + ) + elif message_file.transfer_method == "tool_file": + mapping = { "id": message_file.id, "type": message_file.type, - "url": url, - "belongs_to": message_file.belongs_to or "user", + "transfer_method": message_file.transfer_method, + "tool_file_id": message_file.upload_file_id, } - ) + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=current_app.tenant_id, + user_id=self.from_account_id or self.from_end_user_id or "", + role=CreatedByRole(message_file.created_by_role), + config=FileExtraConfig(), + ) + else: + raise ValueError( + f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}" + ) + files.append(file) - return files + result = [ + {"belongs_to": message_file.belongs_to, **file.to_dict()} + for (file, message_file) in zip(files, message_files) + ] + + return result @property def workflow_run(self): @@ -1003,16 +1091,39 @@ class MessageFile(db.Model): db.Index("message_file_created_by_idx", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - message_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - transfer_method = db.Column(db.String(255), nullable=False) - url = db.Column(db.Text, nullable=True) - belongs_to = db.Column(db.String(255), nullable=True) - upload_file_id = db.Column(StringUUID, nullable=True) - created_by_role = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + def __init__( + self, + *, + message_id: str, + type: FileType, + transfer_method: FileTransferMethod, + url: str | None = None, + 
belongs_to: Literal["user", "assistant"] | None = None, + upload_file_id: str | None = None, + created_by_role: CreatedByRole, + created_by: str, + ): + self.message_id = message_id + self.type = type + self.transfer_method = transfer_method + self.url = url + self.belongs_to = belongs_to + self.upload_file_id = upload_file_id + self.created_by_role = created_by_role + self.created_by = created_by + + id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + message_id: Mapped[str] = db.Column(StringUUID, nullable=False) + type: Mapped[str] = db.Column(db.String(255), nullable=False) + transfer_method: Mapped[str] = db.Column(db.String(255), nullable=False) + url: Mapped[Optional[str]] = db.Column(db.Text, nullable=True) + belongs_to: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True) + upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) + created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False) + created_by: Mapped[str] = db.Column(StringUUID, nullable=False) + created_at: Mapped[datetime] = db.Column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) class MessageAnnotation(db.Model): @@ -1250,21 +1361,58 @@ class UploadFile(db.Model): db.Index("upload_file_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - storage_type = db.Column(db.String(255), nullable=False) - key = db.Column(db.String(255), nullable=False) - name = db.Column(db.String(255), nullable=False) - size = db.Column(db.Integer, nullable=False) - extension = db.Column(db.String(255), nullable=False) - mime_type = db.Column(db.String(255), nullable=True) - created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying")) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, 
server_default=db.text("CURRENT_TIMESTAMP(0)")) - used = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - used_by = db.Column(StringUUID, nullable=True) - used_at = db.Column(db.DateTime, nullable=True) - hash = db.Column(db.String(255), nullable=True) + id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) + storage_type: Mapped[str] = db.Column(db.String(255), nullable=False) + key: Mapped[str] = db.Column(db.String(255), nullable=False) + name: Mapped[str] = db.Column(db.String(255), nullable=False) + size: Mapped[int] = db.Column(db.Integer, nullable=False) + extension: Mapped[str] = db.Column(db.String(255), nullable=False) + mime_type: Mapped[str] = db.Column(db.String(255), nullable=True) + created_by_role: Mapped[str] = db.Column( + db.String(255), nullable=False, server_default=db.text("'account'::character varying") + ) + created_by: Mapped[str] = db.Column(StringUUID, nullable=False) + created_at: Mapped[datetime] = db.Column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True) + used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True) + hash: Mapped[str | None] = db.Column(db.String(255), nullable=True) + + def __init__( + self, + *, + tenant_id: str, + storage_type: str, + key: str, + name: str, + size: int, + extension: str, + mime_type: str, + created_by_role: str, + created_by: str, + created_at: datetime, + used: bool, + used_by: str | None = None, + used_at: datetime | None = None, + hash: str | None = None, + ) -> None: + self.tenant_id = tenant_id + self.storage_type = storage_type + self.key = key + self.name = name + self.size = size + self.extension = extension + self.mime_type = mime_type + self.created_by_role = 
created_by_role + self.created_by = created_by + self.created_at = created_at + self.used = used + self.used_by = used_by + self.used_at = used_at + self.hash = hash class ApiRequest(db.Model): diff --git a/api/models/tools.py b/api/models/tools.py index 861066a2d5..691f3f3cb6 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,6 +1,8 @@ import json +from typing import Optional from sqlalchemy import ForeignKey +from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle @@ -101,7 +103,7 @@ class ApiToolProvider(db.Model): icon = db.Column(db.String(255), nullable=False) # original schema schema = db.Column(db.Text, nullable=False) - schema_type_str = db.Column(db.String(40), nullable=False) + schema_type_str: Mapped[str] = db.Column(db.String(40), nullable=False) # who created this tool user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -133,11 +135,11 @@ class ApiToolProvider(db.Model): return json.loads(self.credentials_str) @property - def user(self) -> Account: + def user(self) -> Account | None: return db.session.query(Account).filter(Account.id == self.user_id).first() @property - def tenant(self) -> Tenant: + def tenant(self) -> Tenant | None: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() @@ -203,11 +205,11 @@ class WorkflowToolProvider(db.Model): return ApiProviderSchemaType.value_of(self.schema_type_str) @property - def user(self) -> Account: + def user(self) -> Account | None: return db.session.query(Account).filter(Account.id == self.user_id).first() @property - def tenant(self) -> Tenant: + def tenant(self) -> Tenant | None: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() @property @@ -215,7 +217,7 @@ class WorkflowToolProvider(db.Model): return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)] @property - def 
app(self) -> App: + def app(self) -> App | None: return db.session.query(App).filter(App.id == self.app_id).first() @@ -288,27 +290,39 @@ class ToolConversationVariables(db.Model): class ToolFile(db.Model): - """ - store the file created by agent - """ - __tablename__ = "tool_files" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_file_pkey"), - # add index for conversation_id db.Index("tool_file_conversation_id_idx", "conversation_id"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - # conversation user id - user_id = db.Column(StringUUID, nullable=False) - # tenant id - tenant_id = db.Column(StringUUID, nullable=False) - # conversation id - conversation_id = db.Column(StringUUID, nullable=True) - # file key - file_key = db.Column(db.String(255), nullable=False) - # mime type - mimetype = db.Column(db.String(255), nullable=False) - # original url - original_url = db.Column(db.String(2048), nullable=True) + user_id: Mapped[str] = db.Column(StringUUID, nullable=False) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) + conversation_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) + file_key: Mapped[str] = db.Column(db.String(255), nullable=False) + mimetype: Mapped[str] = db.Column(db.String(255), nullable=False) + original_url: Mapped[Optional[str]] = db.Column(db.String(2048), nullable=True) + name: Mapped[str] = mapped_column(default="") + size: Mapped[int] = mapped_column(default=-1) + + def __init__( + self, + *, + user_id: str, + tenant_id: str, + conversation_id: Optional[str] = None, + file_key: str, + mimetype: str, + original_url: Optional[str] = None, + name: str, + size: int, + ): + self.user_id = user_id + self.tenant_id = tenant_id + self.conversation_id = conversation_id + self.file_key = file_key + self.mimetype = mimetype + self.original_url = original_url + self.name = name + self.size = size diff --git a/api/models/workflow.py b/api/models/workflow.py index 
9c93ea4cea..e5fbcaf87e 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -5,41 +5,21 @@ from enum import Enum from typing import Any, Optional, Union from sqlalchemy import func -from sqlalchemy.orm import Mapped +from sqlalchemy.orm import Mapped, mapped_column import contexts from constants import HIDDEN_VALUE -from core.app.segments import SecretVariable, Variable, factory from core.helper import encrypter +from core.variables import SecretVariable, Variable from extensions.ext_database import db +from factories import variable_factory from libs import helper +from models.enums import CreatedByRole from .account import Account from .types import StringUUID -class CreatedByRole(Enum): - """ - Created By Role Enum - """ - - ACCOUNT = "account" - END_USER = "end_user" - - @classmethod - def value_of(cls, value: str) -> "CreatedByRole": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid created by role value {value}") - - class WorkflowType(Enum): """ Workflow Type Enum @@ -114,23 +94,23 @@ class Workflow(db.Model): db.Index("workflow_version_idx", "tenant_id", "app_id", "version"), ) - id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) - app_id: Mapped[str] = db.Column(StringUUID, nullable=False) - type: Mapped[str] = db.Column(db.String(255), nullable=False) - version: Mapped[str] = db.Column(db.String(255), nullable=False) - graph: Mapped[str] = db.Column(db.Text) - features: Mapped[str] = db.Column(db.Text) - created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column( + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, 
nullable=False) + type: Mapped[str] = mapped_column(db.String(255), nullable=False) + version: Mapped[str] = mapped_column(db.String(255), nullable=False) + graph: Mapped[str] = mapped_column(db.Text) + _features: Mapped[str] = mapped_column("features") + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) - updated_by: Mapped[str] = db.Column(StringUUID) - updated_at: Mapped[datetime] = db.Column(db.DateTime) - _environment_variables: Mapped[str] = db.Column( + updated_by: Mapped[str] = mapped_column(StringUUID) + updated_at: Mapped[datetime] = mapped_column(db.DateTime) + _environment_variables: Mapped[str] = mapped_column( "environment_variables", db.Text, nullable=False, server_default="{}" ) - _conversation_variables: Mapped[str] = db.Column( + _conversation_variables: Mapped[str] = mapped_column( "conversation_variables", db.Text, nullable=False, server_default="{}" ) @@ -169,6 +149,34 @@ class Workflow(db.Model): def graph_dict(self) -> Mapping[str, Any]: return json.loads(self.graph) if self.graph else {} + @property + def features(self) -> str: + """ + Convert old features structure to new features structure. 
+ """ + if not self._features: + return self._features + + features = json.loads(self._features) + if features.get("file_upload", {}).get("image", {}).get("enabled", False): + image_enabled = True + image_number_limits = int(features["file_upload"]["image"].get("number_limits", 1)) + image_transfer_methods = features["file_upload"]["image"].get( + "transfer_methods", ["remote_url", "local_file"] + ) + features["file_upload"]["enabled"] = image_enabled + features["file_upload"]["number_limits"] = image_number_limits + features["file_upload"]["allowed_upload_methods"] = image_transfer_methods + features["file_upload"]["allowed_file_types"] = ["image"] + features["file_upload"]["allowed_extensions"] = [] + del features["file_upload"]["image"] + self._features = json.dumps(features) + return self._features + + @features.setter + def features(self, value: str) -> None: + self._features = value + @property def features_dict(self) -> Mapping[str, Any]: return json.loads(self.features) if self.features else {} @@ -227,7 +235,7 @@ class Workflow(db.Model): tenant_id = contexts.tenant_id.get() environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables) - results = [factory.build_variable_from_mapping(v) for v in environment_variables_dict.values()] + results = [variable_factory.build_variable_from_mapping(v) for v in environment_variables_dict.values()] # decrypt secret variables value decrypt_func = ( @@ -240,6 +248,10 @@ class Workflow(db.Model): @environment_variables.setter def environment_variables(self, value: Sequence[Variable]): + if not value: + self._environment_variables = "{}" + return + tenant_id = contexts.tenant_id.get() value = list(value) @@ -288,7 +300,7 @@ class Workflow(db.Model): self._conversation_variables = "{}" variables_dict: dict[str, Any] = json.loads(self._conversation_variables) - results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()] + results = 
[variable_factory.build_variable_from_mapping(v) for v in variables_dict.values()] return results @conversation_variables.setter @@ -299,28 +311,6 @@ class Workflow(db.Model): ) -class WorkflowRunTriggeredFrom(Enum): - """ - Workflow Run Triggered From Enum - """ - - DEBUGGING = "debugging" - APP_RUN = "app-run" - - @classmethod - def value_of(cls, value: str) -> "WorkflowRunTriggeredFrom": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid workflow run triggered from value {value}") - - class WorkflowRunStatus(Enum): """ Workflow Run Status Enum @@ -401,7 +391,7 @@ class WorkflowRun(db.Model): graph = db.Column(db.Text) inputs = db.Column(db.Text) status = db.Column(db.String(255), nullable=False) - outputs = db.Column(db.Text) + outputs: Mapped[str] = db.Column(db.Text) error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) @@ -413,27 +403,27 @@ class WorkflowRun(db.Model): @property def created_by_account(self): - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property def graph_dict(self): - return json.loads(self.graph) if self.graph else None + return json.loads(self.graph) if self.graph else {} @property - def inputs_dict(self): - return json.loads(self.inputs) if self.inputs else None + def inputs_dict(self) -> Mapping[str, Any]: + return 
json.loads(self.inputs) if self.inputs else {} @property - def outputs_dict(self): - return json.loads(self.outputs) if self.outputs else None + def outputs_dict(self) -> Mapping[str, Any]: + return json.loads(self.outputs) if self.outputs else {} @property def message(self) -> Optional["Message"]: @@ -640,14 +630,14 @@ class WorkflowNodeExecution(db.Model): @property def created_by_account(self): - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property @@ -672,7 +662,7 @@ class WorkflowNodeExecution(db.Model): extras = {} if self.execution_metadata_dict: - from core.workflow.entities.node_entities import NodeType + from core.workflow.nodes import NodeType if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: tool_info = self.execution_metadata_dict["tool_info"] @@ -759,14 +749,14 @@ class WorkflowAppLog(db.Model): @property def created_by_account(self): - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @@ -800,4 +790,4 @@ class ConversationVariable(db.Model): def to_variable(self) -> Variable: mapping 
= json.loads(self.data) - return factory.build_variable_from_mapping(mapping) + return variable_factory.build_variable_from_mapping(mapping) diff --git a/api/poetry.lock b/api/poetry.lock index 9c45899699..9239b1f887 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -5762,13 +5762,13 @@ sympy = "*" [[package]] name = "openai" -version = "1.51.2" +version = "1.52.0" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.51.2-py3-none-any.whl", hash = "sha256:5c5954711cba931423e471c37ff22ae0fd3892be9b083eee36459865fbbb83fa"}, - {file = "openai-1.51.2.tar.gz", hash = "sha256:c6a51fac62a1ca9df85a522e462918f6bb6bc51a8897032217e453a0730123a6"}, + {file = "openai-1.52.0-py3-none-any.whl", hash = "sha256:0c249f20920183b0a2ca4f7dba7b0452df3ecd0fa7985eb1d91ad884bc3ced9c"}, + {file = "openai-1.52.0.tar.gz", hash = "sha256:95c65a5f77559641ab8f3e4c3a050804f7b51d278870e2ec1f7444080bfe565a"}, ] [package.dependencies] @@ -7098,6 +7098,17 @@ typing-extensions = ">3.10,<4.6.0 || >4.6.0" [package.extras] dev = ["build", "coverage", "furo", "invoke", "mypy", "pytest", "pytest-cov", "pytest-mypy-testing", "ruff", "sphinx", "sphinx-autodoc-typehints", "tox", "twine", "wheel"] +[[package]] +name = "pydub" +version = "0.25.1" +description = "Manipulate audio with an simple and easy high level interface" +optional = false +python-versions = "*" +files = [ + {file = "pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6"}, + {file = "pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f"}, +] + [[package]] name = "pygments" version = "2.18.0" @@ -10784,4 +10795,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "51f048197baebf9ffdc393e5990b9a90185bc5ff515b8b5d2d9b72de900cf6e2" +content-hash = 
"642b2dae9e18ee6671d3d2c7129cb9a77327b69dacba996d00de2a9475d5bad3" diff --git a/api/pyproject.toml b/api/pyproject.toml index ec5266f926..c2c62ffecd 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -158,7 +158,7 @@ nomic = "~3.1.2" novita-client = "~0.5.7" numpy = "~1.26.4" oci = "~2.135.1" -openai = "~1.51.2" +openai = "~1.52.0" openpyxl = "~3.1.5" pandas = { version = "~2.2.2", extras = ["performance", "excel"] } psycopg2-binary = "~2.9.6" @@ -216,6 +216,7 @@ matplotlib = "~3.8.2" newspaper3k = "0.2.8" nltk = "3.8.1" numexpr = "~2.9.0" +pydub = "~0.25.1" qrcode = "~7.4.2" twilio = "~9.0.4" vanna = { version = "0.7.3", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } diff --git a/api/services/agent_service.py b/api/services/agent_service.py index 887fb878b9..c8819535f1 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -68,7 +68,7 @@ class AgentService: "iterations": len(agent_thoughts), }, "iterations": [], - "files": message.files, + "files": message.message_files, } agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict()) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 54594e1175..750d0a8cd2 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -3,9 +3,9 @@ import logging import httpx import yaml # type: ignore -from core.app.segments import factory from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_database import db +from factories import variable_factory from models.account import Account from models.model import App, AppMode, AppModelConfig from models.workflow import Workflow @@ -254,14 +254,18 @@ class AppDslService: # init draft workflow environment_variables_list = workflow_data.get("environment_variables") or [] - environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] + environment_variables = [ + 
variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list + ] conversation_variables_list = workflow_data.get("conversation_variables") or [] - conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] + conversation_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list + ] workflow_service = WorkflowService() draft_workflow = workflow_service.sync_draft_workflow( app_model=app, graph=workflow_data.get("graph", {}), - features=workflow_data.get("../core/app/features", {}), + features=workflow_data.get("features", {}), unique_hash=None, account=account, environment_variables=environment_variables, @@ -295,9 +299,13 @@ class AppDslService: # sync draft workflow environment_variables_list = workflow_data.get("environment_variables") or [] - environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] + environment_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list + ] conversation_variables_list = workflow_data.get("conversation_variables") or [] - conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] + conversation_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list + ] draft_workflow = workflow_service.sync_draft_workflow( app_model=app_model, graph=workflow_data.get("graph", {}), diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 26517a05fb..83a9a16904 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -1,4 +1,4 @@ -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Any, Union from openai._exceptions import RateLimitError @@ -23,7 +23,7 @@ class AppGenerateService: cls, app_model: App, user: 
Union[Account, EndUser], - args: Any, + args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, ): diff --git a/api/services/file_service.py b/api/services/file_service.py index bedec76334..0b35561600 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -9,72 +9,55 @@ from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound from configs import dify_config -from core.file.upload_file_parser import UploadFileParser +from constants import ( + AUDIO_EXTENSIONS, + DOCUMENT_EXTENSIONS, + IMAGE_EXTENSIONS, + VIDEO_EXTENSIONS, +) +from core.file import helpers as file_helpers from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Account from models.model import EndUser, UploadFile -from services.errors.file import FileTooLargeError, UnsupportedFileTypeError - -IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] -IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) - -ALLOWED_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"] -UNSTRUCTURED_ALLOWED_EXTENSIONS = [ - "txt", - "markdown", - "md", - "pdf", - "html", - "htm", - "xlsx", - "xls", - "docx", - "csv", - "eml", - "msg", - "pptx", - "ppt", - "xml", - "epub", -] +from services.errors.file import FileNotExistsError, FileTooLargeError, UnsupportedFileTypeError PREVIEW_WORDS_LIMIT = 3000 class FileService: @staticmethod - def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile: + def upload_file(file: FileStorage, user: Union[Account, EndUser]) -> UploadFile: + # get file name filename = file.filename - extension = file.filename.split(".")[-1] + if not filename: + raise FileNotExistsError + extension = filename.split(".")[-1] if len(filename) > 200: filename = filename.split(".")[0][:200] + "." 
+ extension - etl_type = dify_config.ETL_TYPE - allowed_extensions = ( - UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS - if etl_type == "Unstructured" - else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS - ) - if extension.lower() not in allowed_extensions or only_image and extension.lower() not in IMAGE_EXTENSIONS: - raise UnsupportedFileTypeError() - # read file content file_content = file.read() # get file size file_size = len(file_content) - if extension.lower() in IMAGE_EXTENSIONS: + # select file size limit + if extension in IMAGE_EXTENSIONS: file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + elif extension in VIDEO_EXTENSIONS: + file_size_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024 + elif extension in AUDIO_EXTENSIONS: + file_size_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024 else: file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 + # check if the file size is exceeded if file_size > file_size_limit: message = f"File size exceeded. 
{file_size} > {file_size_limit}" raise FileTooLargeError(message) - # user uuid as file name + # generate file key file_uuid = str(uuid.uuid4()) if isinstance(user, Account): @@ -150,9 +133,7 @@ class FileService: # extract text from file extension = upload_file.extension - etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS - if extension.lower() not in allowed_extensions: + if extension.lower() not in DOCUMENT_EXTENSIONS: raise UnsupportedFileTypeError() text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True) @@ -161,8 +142,10 @@ class FileService: return text @staticmethod - def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str) -> tuple[Generator, str]: - result = UploadFileParser.verify_image_file_signature(file_id, timestamp, nonce, sign) + def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str): + result = file_helpers.verify_image_signature( + upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign + ) if not result: raise NotFound("File not found or signature is invalid") @@ -180,6 +163,21 @@ class FileService: return generator, upload_file.mime_type + @staticmethod + def get_signed_file_preview(file_id: str, timestamp: str, nonce: str, sign: str): + result = file_helpers.verify_file_signature(upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign) + if not result: + raise NotFound("File not found or signature is invalid") + + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + + if not upload_file: + raise NotFound("File not found or signature is invalid") + + generator = storage.load(upload_file.key, stream=True) + + return generator, upload_file.mime_type + @staticmethod def get_public_image_preview(file_id: str) -> tuple[Generator, str]: upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() diff --git 
a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 5868ef3755..833881b668 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -1,6 +1,7 @@ import json +from collections.abc import Mapping from datetime import datetime -from typing import Optional +from typing import Any, Optional from sqlalchemy import or_ @@ -21,9 +22,9 @@ class WorkflowToolManageService: Service class for managing workflow tools. """ - @classmethod + @staticmethod def create_workflow_tool( - cls, + *, user_id: str, tenant_id: str, workflow_app_id: str, @@ -31,22 +32,10 @@ class WorkflowToolManageService: label: str, icon: dict, description: str, - parameters: list[dict], + parameters: Mapping[str, Any], privacy_policy: str = "", labels: Optional[list[str]] = None, ) -> dict: - """ - Create a workflow tool. - :param user_id: the user id - :param tenant_id: the tenant id - :param name: the name - :param icon: the icon - :param description: the description - :param parameters: the parameters - :param privacy_policy: the privacy policy - :param labels: labels - :return: the created tool - """ WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) # check if the name is unique @@ -63,12 +52,11 @@ class WorkflowToolManageService: if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") - app: App = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first() - + app = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first() if app is None: raise ValueError(f"App {workflow_app_id} not found") - workflow: Workflow = app.workflow + workflow = app.workflow if workflow is None: raise ValueError(f"Workflow not found for app {workflow_app_id}") diff --git a/api/services/workflow/workflow_converter.py 
b/api/services/workflow/workflow_converter.py index db1a036e68..75c11afa94 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -13,12 +13,12 @@ from core.app.app_config.entities import ( from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager -from core.file.file_obj import FileExtraConfig +from core.file.models import FileExtraConfig from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform -from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes import NodeType from events.app_event import app_was_created from extensions.ext_database import db from models.account import Account @@ -522,7 +522,7 @@ class WorkflowConverter: "vision": { "enabled": file_upload is not None, "variable_selector": ["sys", "files"] if file_upload is not None else None, - "configs": {"detail": file_upload.image_config["detail"]} + "configs": {"detail": file_upload.image_config.detail} if file_upload is not None and file_upload.image_config is not None else None, }, diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index b4f0882a3a..f89487415d 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -4,9 +4,9 @@ from flask_sqlalchemy.pagination import Pagination from sqlalchemy import and_, or_ from extensions.ext_database import db -from models import CreatedByRole -from models.model import App, EndUser -from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus +from models import App, EndUser, WorkflowAppLog, WorkflowRun +from models.enums import CreatedByRole +from 
models.workflow import WorkflowRunStatus class WorkflowAppService: @@ -21,7 +21,7 @@ class WorkflowAppService: WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id ) - status = WorkflowRunStatus.value_of(args.get("status")) if args.get("status") else None + status = WorkflowRunStatus.value_of(args.get("status", "")) if args.get("status") else None keyword = args["keyword"] if keyword or status: query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) @@ -42,7 +42,7 @@ class WorkflowAppService: query = query.outerjoin( EndUser, - and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value), + and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER), ).filter(or_(*keyword_conditions)) if status: diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index b7b3abeaa2..d8ee323908 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,11 +1,11 @@ from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.enums import WorkflowRunTriggeredFrom from models.model import App from models.workflow import ( WorkflowNodeExecution, WorkflowNodeExecutionTriggeredFrom, WorkflowRun, - WorkflowRunTriggeredFrom, ) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0ff81f1f7e..7187d40517 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -6,19 +6,20 @@ from typing import Optional from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.segments import Variable from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.node_entities import NodeRunResult, NodeType +from 
core.variables import Variable +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.nodes import NodeType from core.workflow.nodes.event import RunCompletedEvent -from core.workflow.nodes.node_mapping import node_classes +from core.workflow.nodes.node_mapping import node_type_classes_mapping from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from models.account import Account +from models.enums import CreatedByRole from models.model import App, AppMode from models.workflow import ( - CreatedByRole, Workflow, WorkflowNodeExecution, WorkflowNodeExecutionStatus, @@ -175,7 +176,7 @@ class WorkflowService: """ # return default block config default_block_configs = [] - for node_type, node_class in node_classes.items(): + for node_type, node_class in node_type_classes_mapping.items(): default_config = node_class.get_default_config() if default_config: default_block_configs.append(default_config) @@ -189,10 +190,10 @@ class WorkflowService: :param filters: filter by node config parameters. 
:return: """ - node_type_enum: NodeType = NodeType.value_of(node_type) + node_type_enum: NodeType = NodeType(node_type) # return default block config - node_class = node_classes.get(node_type_enum) + node_class = node_type_classes_mapping.get(node_type_enum) if not node_class: return None @@ -251,7 +252,7 @@ class WorkflowService: workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value workflow_node_execution.index = 1 workflow_node_execution.node_id = node_id - workflow_node_execution.node_type = node_instance.node_type.value + workflow_node_execution.node_type = node_instance.node_type workflow_node_execution.title = node_instance.node_data.title workflow_node_execution.elapsed_time = time.perf_counter() - start_at workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py index ec013183b7..f08d270b4b 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/http.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -6,6 +6,8 @@ import httpx import pytest from _pytest.monkeypatch import MonkeyPatch +from core.helper import ssrf_proxy + MOCK = os.getenv("MOCK_SWITCH", "false") == "true" @@ -24,10 +26,16 @@ class MockedHttp: # get data, files data = kwargs.get("data") files = kwargs.get("files") + json = kwargs.get("json") + content = kwargs.get("content") if data is not None: resp = dumps(data).encode("utf-8") elif files is not None: resp = dumps(files).encode("utf-8") + elif json is not None: + resp = dumps(json).encode("utf-8") + elif content is not None: + resp = content else: resp = b"OK" @@ -43,6 +51,6 @@ def setup_http_mock(request, monkeypatch: MonkeyPatch): yield return - monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request) + monkeypatch.setattr(ssrf_proxy, "make_request", MockedHttp.httpx_request) yield monkeypatch.undo() diff --git 
a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 952c90674d..fd0f25cf04 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -6,7 +6,7 @@ from typing import cast import pytest from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import NodeRunResult, UserFrom +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph @@ -14,6 +14,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.entities import CodeNodeData +from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 65aaa0bddd..9eea63f722 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -5,13 +5,13 @@ from urllib.parse import urlencode import pytest from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import 
GraphRuntimeState -from core.workflow.nodes.http_request.http_request_node import HttpRequestNode +from core.workflow.nodes.http_request.node import HttpRequestNode +from models.enums import UserFrom from models.workflow import WorkflowType from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock @@ -211,7 +211,16 @@ def test_json(setup_http_mock): }, "headers": "X-Header:123", "params": "A:b", - "body": {"type": "json", "data": '{"a": "{{#a.b123.args1#}}"}'}, + "body": { + "type": "json", + "data": [ + { + "key": "", + "type": "text", + "value": '{"a": "{{#a.b123.args1#}}"}', + }, + ], + }, }, } ) @@ -243,7 +252,21 @@ def test_x_www_form_urlencoded(setup_http_mock): }, "headers": "X-Header:123", "params": "A:b", - "body": {"type": "x-www-form-urlencoded", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, + "body": { + "type": "x-www-form-urlencoded", + "data": [ + { + "key": "a", + "type": "text", + "value": "{{#a.b123.args1#}}", + }, + { + "key": "b", + "type": "text", + "value": "{{#a.b123.args2#}}", + }, + ], + }, }, } ) @@ -275,7 +298,21 @@ def test_form_data(setup_http_mock): }, "headers": "X-Header:123", "params": "A:b", - "body": {"type": "form-data", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, + "body": { + "type": "form-data", + "data": [ + { + "key": "a", + "type": "text", + "value": "{{#a.b123.args1#}}", + }, + { + "key": "b", + "type": "text", + "value": "{{#a.b123.args2#}}", + }, + ], + }, }, } ) @@ -310,7 +347,7 @@ def test_none_data(setup_http_mock): }, "headers": "X-Header:123", "params": "A:b", - "body": {"type": "none", "data": "123123123"}, + "body": {"type": "none", "data": []}, }, } ) @@ -366,7 +403,21 @@ def test_multi_colons_parse(setup_http_mock): }, "params": "Referer:http://example1.com\nRedirect:http://example2.com", "headers": "Referer:http://example3.com\nRedirect:http://example4.com", - "body": {"type": "form-data", "data": "Referer:http://example5.com\nRedirect:http://example6.com"}, + "body": { 
+ "type": "form-data", + "data": [ + { + "key": "Referer", + "type": "text", + "value": "http://example5.com", + }, + { + "key": "Redirect", + "type": "text", + "value": "http://example6.com", + }, + ], + }, }, } ) @@ -377,5 +428,5 @@ def test_multi_colons_parse(setup_http_mock): resp = result.outputs assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "") - assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request", "") - assert "http://example3.com" == resp.get("headers", {}).get("referer") + assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "") + # assert "http://example3.com" == resp.get("headers", {}).get("referer") diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index dfb43650d2..9a23949b38 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -13,15 +13,15 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import ModelProviderFactory -from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.event import RunCompletedEvent -from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.llm.node import LLMNode from extensions.ext_database import db +from models.enums import UserFrom from models.provider import ProviderType from 
models.workflow import WorkflowNodeExecutionStatus, WorkflowType diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 4c695f7443..42a058d29b 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -12,7 +12,6 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph @@ -20,6 +19,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from extensions.ext_database import db +from models.enums import UserFrom from models.provider import ProviderType """FOR MOCK FIXTURES, DO NOT REMOVE""" diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 073c4bb799..51d61a95ea 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -4,13 +4,13 @@ import uuid import pytest from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import 
SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 4d94cdb28a..4068e796b7 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -2,13 +2,14 @@ import time import uuid from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import NodeRunResult, UserFrom +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.tool.tool_node import ToolNode +from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType diff --git a/api/tests/integration_tests/workflow/test_sync_workflow.py b/api/tests/integration_tests/workflow/test_sync_workflow.py new file mode 100644 index 0000000000..df2ec95ebc --- /dev/null +++ b/api/tests/integration_tests/workflow/test_sync_workflow.py @@ -0,0 +1,57 @@ +""" +This test file is used to verify the compatibility of Workflow before and after supporting multiple file types. 
+""" + +import json + +from models import Workflow + +OLD_VERSION_WORKFLOW_FEATURES = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, +} + +NEW_VERSION_WORKFLOW_FEATURES = { + "file_upload": { + "enabled": True, + "allowed_file_types": ["image"], + "allowed_extensions": [], + "allowed_upload_methods": ["remote_url", "local_file"], + "number_limits": 6, + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, +} + + +def test_workflow_features(): + workflow = Workflow( + tenant_id="", + app_id="", + type="", + version="", + graph="", + features=json.dumps(OLD_VERSION_WORKFLOW_FEATURES), + created_by="", + environment_variables=[], + conversation_variables=[], + ) + + assert workflow.features_dict == NEW_VERSION_WORKFLOW_FEATURES diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py index 0824c8e9e9..72d277fad4 100644 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -2,7 +2,7 @@ from uuid import uuid4 import pytest -from core.app.segments import ( +from core.variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, @@ -11,43 +11,43 @@ from core.app.segments import ( ObjectSegment, SecretVariable, StringVariable, - factory, ) -from 
core.app.segments.exc import VariableError +from core.variables.exc import VariableError +from factories import variable_factory def test_string_variable(): test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"} - result = factory.build_variable_from_mapping(test_data) + result = variable_factory.build_variable_from_mapping(test_data) assert isinstance(result, StringVariable) def test_integer_variable(): test_data = {"value_type": "number", "name": "test_int", "value": 42} - result = factory.build_variable_from_mapping(test_data) + result = variable_factory.build_variable_from_mapping(test_data) assert isinstance(result, IntegerVariable) def test_float_variable(): test_data = {"value_type": "number", "name": "test_float", "value": 3.14} - result = factory.build_variable_from_mapping(test_data) + result = variable_factory.build_variable_from_mapping(test_data) assert isinstance(result, FloatVariable) def test_secret_variable(): test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"} - result = factory.build_variable_from_mapping(test_data) + result = variable_factory.build_variable_from_mapping(test_data) assert isinstance(result, SecretVariable) def test_invalid_value_type(): test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"} with pytest.raises(VariableError): - factory.build_variable_from_mapping(test_data) + variable_factory.build_variable_from_mapping(test_data) def test_build_a_blank_string(): - result = factory.build_variable_from_mapping( + result = variable_factory.build_variable_from_mapping( { "value_type": "string", "name": "blank", @@ -59,7 +59,7 @@ def test_build_a_blank_string(): def test_build_a_object_variable_with_none_value(): - var = factory.build_segment( + var = variable_factory.build_segment( { "key1": None, } @@ -79,7 +79,7 @@ def test_object_variable(): "key2": 2, }, } - variable = factory.build_variable_from_mapping(mapping) + variable = 
variable_factory.build_variable_from_mapping(mapping) assert isinstance(variable, ObjectSegment) assert isinstance(variable.value["key1"], str) assert isinstance(variable.value["key2"], int) @@ -96,7 +96,7 @@ def test_array_string_variable(): "text", ], } - variable = factory.build_variable_from_mapping(mapping) + variable = variable_factory.build_variable_from_mapping(mapping) assert isinstance(variable, ArrayStringVariable) assert isinstance(variable.value[0], str) assert isinstance(variable.value[1], str) @@ -113,7 +113,7 @@ def test_array_number_variable(): 2.0, ], } - variable = factory.build_variable_from_mapping(mapping) + variable = variable_factory.build_variable_from_mapping(mapping) assert isinstance(variable, ArrayNumberVariable) assert isinstance(variable.value[0], int) assert isinstance(variable.value[1], float) @@ -136,7 +136,7 @@ def test_array_object_variable(): }, ], } - variable = factory.build_variable_from_mapping(mapping) + variable = variable_factory.build_variable_from_mapping(mapping) assert isinstance(variable, ArrayObjectVariable) assert isinstance(variable.value[0], dict) assert isinstance(variable.value[1], dict) @@ -146,13 +146,13 @@ def test_array_object_variable(): assert isinstance(variable.value[1]["key2"], int) -def test_variable_cannot_large_than_5_kb(): +def test_variable_cannot_large_than_200_kb(): with pytest.raises(VariableError): - factory.build_variable_from_mapping( + variable_factory.build_variable_from_mapping( { "id": str(uuid4()), "value_type": "string", "name": "test_text", - "value": "a" * 1024 * 6, + "value": "a" * 1024 * 201, } ) diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/app/segments/test_segment.py index 73002623f0..3b1715ab45 100644 --- a/api/tests/unit_tests/core/app/segments/test_segment.py +++ b/api/tests/unit_tests/core/app/segments/test_segment.py @@ -1,5 +1,5 @@ -from core.app.segments import SecretVariable, StringSegment, parser from core.helper import 
encrypter +from core.variables import SecretVariable, StringSegment from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey @@ -13,12 +13,13 @@ def test_segment_group_to_text(): environment_variables=[ SecretVariable(name="secret_key", value="fake-secret-key"), ], + conversation_variables=[], ) variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." ) - segments_group = parser.convert_template(template=template, variable_pool=variable_pool) + segments_group = variable_pool.convert_template(template) assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key." assert segments_group.log == ( @@ -32,9 +33,10 @@ def test_convert_constant_to_segment_group(): system_variables={}, user_inputs={}, environment_variables=[], + conversation_variables=[], ) template = "Hello, world!" - segments_group = parser.convert_template(template=template, variable_pool=variable_pool) + segments_group = variable_pool.convert_template(template) assert segments_group.text == "Hello, world!" assert segments_group.log == "Hello, world!" 
@@ -46,9 +48,10 @@ def test_convert_variable_to_segment_group(): }, user_inputs={}, environment_variables=[], + conversation_variables=[], ) template = "{{#sys.user_id#}}" - segments_group = parser.convert_template(template=template, variable_pool=variable_pool) + segments_group = variable_pool.convert_template(template) assert segments_group.text == "fake-user-id" assert segments_group.log == "fake-user-id" assert segments_group.value == [StringSegment(value="fake-user-id")] diff --git a/api/tests/unit_tests/core/app/segments/test_variables.py b/api/tests/unit_tests/core/app/segments/test_variables.py index 6179675cde..0c264c15a0 100644 --- a/api/tests/unit_tests/core/app/segments/test_variables.py +++ b/api/tests/unit_tests/core/app/segments/test_variables.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from core.app.segments import ( +from core.variables import ( FloatVariable, IntegerVariable, ObjectVariable, diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index d6e6b0b79c..c688d3952b 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -6,7 +6,7 @@ import pytest from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request -@patch("httpx.request") +@patch("httpx.Client.request") def test_successful_request(mock_request): mock_response = MagicMock() mock_response.status_code = 200 @@ -16,7 +16,7 @@ def test_successful_request(mock_request): assert response.status_code == 200 -@patch("httpx.request") +@patch("httpx.Client.request") def test_retry_exceed_max_retries(mock_request): mock_response = MagicMock() mock_response.status_code = 500 @@ -29,7 +29,7 @@ def test_retry_exceed_max_retries(mock_request): assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com" -@patch("httpx.request") +@patch("httpx.Client.request") 
def test_retry_logic_success(mock_request): side_effects = [] diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 24b338601d..ece2173090 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -1,11 +1,16 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from core.app.app_config.entities import ModelConfigEntity -from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar +from core.file import File, FileExtraConfig, FileTransferMethod, FileType, ImageConfig from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, UserPromptMessage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + UserPromptMessage, +) from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_template_parser import PromptTemplateParser @@ -123,32 +128,30 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg model_config_mock, _, messages, inputs, context = get_chat_model_args files = [ - FileVar( + File( id="file1", tenant_id="tenant1", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, - url="https://example.com/image1.jpg", - extra_config=FileExtraConfig( - image_config={ - "detail": "high", - } - ), + remote_url="https://example.com/image1.jpg", + _extra_config=FileExtraConfig(image_config=ImageConfig(detail=ImagePromptMessageContent.DETAIL.HIGH)), ) ] prompt_transform = AdvancedPromptTransform() 
prompt_transform._calculate_rest_token = MagicMock(return_value=2000) - prompt_messages = prompt_transform._get_chat_model_prompt_messages( - prompt_template=messages, - inputs=inputs, - query=None, - files=files, - context=context, - memory_config=None, - memory=None, - model_config=model_config_mock, - ) + with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string: + mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url)) + prompt_messages = prompt_transform._get_chat_model_prompt_messages( + prompt_template=messages, + inputs=inputs, + query=None, + files=files, + context=context, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) assert len(prompt_messages) == 4 assert prompt_messages[0].role == PromptMessageRole.SYSTEM @@ -157,7 +160,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg ) assert isinstance(prompt_messages[3].content, list) assert len(prompt_messages[3].content) == 2 - assert prompt_messages[3].content[1].data == files[0].url + assert prompt_messages[3].content[1].data == files[0].remote_url @pytest.fixture diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py new file mode 100644 index 0000000000..aa61c1c6f7 --- /dev/null +++ b/api/tests/unit_tests/core/test_file.py @@ -0,0 +1,40 @@ +from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType + + +def test_file_loads_and_dumps(): + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image1.jpg", + ) + + file_dict = file.model_dump() + assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY + assert file_dict["type"] == file.type.value + assert isinstance(file_dict["type"], str) + assert file_dict["transfer_method"] == file.transfer_method.value + assert isinstance(file_dict["transfer_method"], str) + 
assert "_extra_config" not in file_dict + + file_obj = File.model_validate(file_dict) + assert file_obj.id == file.id + assert file_obj.tenant_id == file.tenant_id + assert file_obj.type == file.type + assert file_obj.transfer_method == file.transfer_method + assert file_obj.remote_url == file.remote_url + + +def test_file_to_dict(): + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image1.jpg", + ) + + file_dict = file.to_dict() + assert "_extra_config" not in file_dict + assert "url" in file_dict diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py deleted file mode 100644 index 279a6cdbc3..0000000000 --- a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py +++ /dev/null @@ -1,56 +0,0 @@ -import pytest - -from core.tools.entities.tool_entities import ToolParameter -from core.tools.utils.tool_parameter_converter import ToolParameterConverter - - -def test_get_parameter_type(): - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string" - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string" - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean" - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number" - with pytest.raises(ValueError): - ToolParameterConverter.get_parameter_type("unsupported_type") - - -def test_cast_parameter_by_type(): - # string - assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test" - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1" - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0" - 
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == "" - - # secret input - assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test" - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1" - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0" - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == "" - - # select - assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test" - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1" - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0" - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == "" - - # boolean - true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"] - for value in true_values: - assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True - - false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""] - for value in false_values: - assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False - - # number - assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1 - assert ToolParameterConverter.cast_parameter_by_type("1.0", ToolParameter.ToolParameterType.NUMBER) == 1.0 - assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0 - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1 - assert ToolParameterConverter.cast_parameter_by_type(1.0, 
ToolParameter.ToolParameterType.NUMBER) == 1.0 - assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0 - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None - - # unknown - assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1" - assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1" - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_type.py b/api/tests/unit_tests/core/tools/test_tool_parameter_type.py new file mode 100644 index 0000000000..8a41678267 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_parameter_type.py @@ -0,0 +1,49 @@ +from core.tools.entities.tool_entities import ToolParameter + + +def test_get_parameter_type(): + assert ToolParameter.ToolParameterType.STRING.as_normal_type() == "string" + assert ToolParameter.ToolParameterType.SELECT.as_normal_type() == "string" + assert ToolParameter.ToolParameterType.SECRET_INPUT.as_normal_type() == "string" + assert ToolParameter.ToolParameterType.BOOLEAN.as_normal_type() == "boolean" + assert ToolParameter.ToolParameterType.NUMBER.as_normal_type() == "number" + assert ToolParameter.ToolParameterType.FILE.as_normal_type() == "file" + assert ToolParameter.ToolParameterType.FILES.as_normal_type() == "files" + + +def test_cast_parameter_by_type(): + # string + assert ToolParameter.ToolParameterType.STRING.cast_value("test") == "test" + assert ToolParameter.ToolParameterType.STRING.cast_value(1) == "1" + assert ToolParameter.ToolParameterType.STRING.cast_value(1.0) == "1.0" + assert ToolParameter.ToolParameterType.STRING.cast_value(None) == "" + + # secret input + assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value("test") == "test" + assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1) == "1" + assert 
ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1.0) == "1.0" + assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(None) == "" + + # select + assert ToolParameter.ToolParameterType.SELECT.cast_value("test") == "test" + assert ToolParameter.ToolParameterType.SELECT.cast_value(1) == "1" + assert ToolParameter.ToolParameterType.SELECT.cast_value(1.0) == "1.0" + assert ToolParameter.ToolParameterType.SELECT.cast_value(None) == "" + + # boolean + true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"] + for value in true_values: + assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is True + + false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""] + for value in false_values: + assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is False + + # number + assert ToolParameter.ToolParameterType.NUMBER.cast_value("1") == 1 + assert ToolParameter.ToolParameterType.NUMBER.cast_value("1.0") == 1.0 + assert ToolParameter.ToolParameterType.NUMBER.cast_value("-1.0") == -1.0 + assert ToolParameter.ToolParameterType.NUMBER.cast_value(1) == 1 + assert ToolParameter.ToolParameterType.NUMBER.cast_value(1.0) == 1.0 + assert ToolParameter.ToolParameterType.NUMBER.cast_value(-1.0) == -1.0 + assert ToolParameter.ToolParameterType.NUMBER.cast_value(None) is None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 197288adba..9f1ba7b6af 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -1,7 +1,7 @@ from unittest.mock import patch from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, UserFrom +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from 
core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( @@ -18,7 +18,8 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent -from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.llm.node import LLMNode +from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType @@ -86,7 +87,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): {"role": "system", "text": "say hi"}, {"role": "user", "text": "{{#start.query#}}"}, ], - "vision": {"configs": {"detail": "high"}, "enabled": False}, + "vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False}, }, "id": "llm1", }, @@ -105,7 +106,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): {"role": "system", "text": "say bye"}, {"role": "user", "text": "{{#start.query#}}"}, ], - "vision": {"configs": {"detail": "high"}, "enabled": False}, + "vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False}, }, "id": "llm2", }, @@ -124,7 +125,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): {"role": "system", "text": "say good morning"}, {"role": "user", "text": "{{#start.query#}}"}, ], - "vision": {"configs": {"detail": "high"}, "enabled": False}, + "vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False}, }, "id": "llm3", }, diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index fe4ede6335..0369f3fa44 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ 
b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -3,7 +3,6 @@ import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph @@ -11,6 +10,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.answer.answer_node import AnswerNode from extensions.ext_database import db +from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py index 6b1d1e9070..f6b3be8250 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py @@ -2,7 +2,6 @@ import uuid from collections.abc import Generator from datetime import datetime, timezone -from core.workflow.entities.node_entities import NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( @@ -14,6 +13,7 @@ from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.start.entities import StartNodeData @@ -39,7 +39,7 @@ def _publish_events(graph: Graph, 
next_node_id: str) -> Generator[GraphEngineEve node_execution_id = str(uuid.uuid4()) node_config = graph.node_id_config_mapping[next_node_id] - node_type = NodeType.value_of(node_config.get("data", {}).get("type")) + node_type = NodeType(node_config.get("data", {}).get("type")) mock_node_data = StartNodeData(**{"title": "demo", "variables": []}) yield NodeRunStartedEvent( diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index b3a89061b2..d755faee8a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -3,7 +3,7 @@ import uuid from unittest.mock import patch from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import NodeRunResult, UserFrom +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph @@ -12,6 +12,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.iteration.iteration_node import IterationNode from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index cb2e99a854..2f0aa28b48 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -3,7 +3,6 @@ import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom -from 
core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph @@ -11,6 +10,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.answer.answer_node import AnswerNode from extensions.ext_database import db +from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py new file mode 100644 index 0000000000..7471e13e1e --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -0,0 +1,158 @@ +from unittest.mock import Mock, patch + +import pytest + +from core.file import File, FileTransferMethod +from core.variables import ArrayFileSegment +from core.variables.variables import StringVariable +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData +from core.workflow.nodes.document_extractor.node import ( + _extract_text_from_doc, + _extract_text_from_pdf, + _extract_text_from_plain_text, +) +from core.workflow.nodes.enums import NodeType +from models.workflow import WorkflowNodeExecutionStatus + + +@pytest.fixture +def document_extractor_node(): + node_data = DocumentExtractorNodeData( + title="Test Document Extractor", + variable_selector=["node_id", "variable_name"], + ) + return DocumentExtractorNode( + id="test_node_id", + config={"id": "test_node_id", "data": node_data.model_dump()}, + graph_init_params=Mock(), + graph=Mock(), + graph_runtime_state=Mock(), + ) + + +@pytest.fixture +def mock_graph_runtime_state(): + 
return Mock() + + +def test_run_variable_not_found(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + mock_graph_runtime_state.variable_pool.get.return_value = None + + result = document_extractor_node._run() + + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error is not None + assert "File variable not found" in result.error + + +def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + mock_graph_runtime_state.variable_pool.get.return_value = StringVariable( + value="Not an ArrayFileSegment", name="test" + ) + + result = document_extractor_node._run() + + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error is not None + assert "is not an ArrayFileSegment" in result.error + + +@pytest.mark.parametrize( + ("mime_type", "file_content", "expected_text", "transfer_method"), + [ + ("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE), + ("application/pdf", b"%PDF-1.5\n%Test PDF content", ["Mocked PDF content"], FileTransferMethod.LOCAL_FILE), + ( + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + b"PK\x03\x04", + ["Mocked DOCX content"], + FileTransferMethod.LOCAL_FILE, + ), + ("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL), + ], +) +def test_run_extract_text( + document_extractor_node, + mock_graph_runtime_state, + mime_type, + file_content, + expected_text, + transfer_method, + monkeypatch, +): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + + mock_file = Mock(spec=File) + mock_file.mime_type = mime_type + mock_file.transfer_method = transfer_method + mock_file.related_id = "test_file_id" if transfer_method == 
FileTransferMethod.LOCAL_FILE else None + mock_file.remote_url = "https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None + + mock_array_file_segment = Mock(spec=ArrayFileSegment) + mock_array_file_segment.value = [mock_file] + + mock_graph_runtime_state.variable_pool.get.return_value = mock_array_file_segment + + mock_download = Mock(return_value=file_content) + mock_ssrf_proxy_get = Mock() + mock_ssrf_proxy_get.return_value.content = file_content + mock_ssrf_proxy_get.return_value.raise_for_status = Mock() + + monkeypatch.setattr("core.file.file_manager.download", mock_download) + monkeypatch.setattr("core.helper.ssrf_proxy.get", mock_ssrf_proxy_get) + + if mime_type == "application/pdf": + mock_pdf_extract = Mock(return_value=expected_text[0]) + monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) + elif mime_type.startswith("application/vnd.openxmlformats"): + mock_docx_extract = Mock(return_value=expected_text[0]) + monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_doc", mock_docx_extract) + + result = document_extractor_node._run() + + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["text"] == expected_text + + if transfer_method == FileTransferMethod.REMOTE_URL: + mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt") + elif transfer_method == FileTransferMethod.LOCAL_FILE: + mock_download.assert_called_once_with(mock_file) + + +def test_extract_text_from_plain_text(): + text = _extract_text_from_plain_text(b"Hello, world!") + assert text == "Hello, world!" 
+ + +@patch("pypdfium2.PdfDocument") +def test_extract_text_from_pdf(mock_pdf_document): + mock_page = Mock() + mock_text_page = Mock() + mock_text_page.get_text_range.return_value = "PDF content" + mock_page.get_textpage.return_value = mock_text_page + mock_pdf_document.return_value = [mock_page] + text = _extract_text_from_pdf(b"%PDF-1.5\n%Test PDF content") + assert text == "PDF content" + + +@patch("docx.Document") +def test_extract_text_from_doc(mock_document): + mock_paragraph1 = Mock() + mock_paragraph1.text = "Paragraph 1" + mock_paragraph2 = Mock() + mock_paragraph2.text = "Paragraph 2" + mock_document.return_value.paragraphs = [mock_paragraph1, mock_paragraph2] + + text = _extract_text_from_doc(b"PK\x03\x04") + assert text == "Paragraph 1\nParagraph 2" + + +def test_node_type(document_extractor_node): + assert document_extractor_node._node_type == NodeType.DOCUMENT_EXTRACTOR diff --git a/api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py new file mode 100644 index 0000000000..28ecdaadb0 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py @@ -0,0 +1,202 @@ +import httpx + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file import File, FileTransferMethod, FileType +from core.variables import FileVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState +from core.workflow.nodes.answer import AnswerStreamGenerateRoute +from core.workflow.nodes.end import EndStreamParam +from core.workflow.nodes.http_request import ( + BodyData, + HttpRequestNode, + HttpRequestNodeAuthorization, + HttpRequestNodeBody, + HttpRequestNodeData, +) +from core.workflow.nodes.http_request.executor import _plain_text_to_dict +from models.enums import UserFrom +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def 
test_plain_text_to_dict(): + assert _plain_text_to_dict("aa\n cc:") == {"aa": "", "cc": ""} + assert _plain_text_to_dict("aa:bb\n cc:dd") == {"aa": "bb", "cc": "dd"} + assert _plain_text_to_dict("aa:bb\n cc:dd\n") == {"aa": "bb", "cc": "dd"} + assert _plain_text_to_dict("aa:bb\n\n cc : dd\n\n") == {"aa": "bb", "cc": "dd"} + + +def test_http_request_node_binary_file(monkeypatch): + data = HttpRequestNodeData( + title="test", + method="post", + url="http://example.org/post", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="", + params="", + body=HttpRequestNodeBody( + type="binary", + data=[ + BodyData( + key="file", + type="file", + value="", + file=["1111", "file"], + ) + ], + ), + ) + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add( + ["1111", "file"], + FileVariable( + name="file", + value=File( + tenant_id="1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1111", + ), + ), + ) + node = HttpRequestNode( + id="1", + config={ + "id": "1", + "data": data.model_dump(), + }, + graph_init_params=GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ), + graph=Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, + ), + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, + ), + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ), + ) + monkeypatch.setattr( + "core.workflow.nodes.http_request.executor.file_manager.download", + lambda *args, **kwargs: b"test", + ) + monkeypatch.setattr( + "core.helper.ssrf_proxy.post", + lambda *args, **kwargs: httpx.Response(200, content=kwargs["content"]), + ) + result = node._run() + 
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["body"] == "test" + + +def test_http_request_node_form_with_file(monkeypatch): + data = HttpRequestNodeData( + title="test", + method="post", + url="http://example.org/post", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="", + params="", + body=HttpRequestNodeBody( + type="form-data", + data=[ + BodyData( + key="file", + type="file", + file=["1111", "file"], + ), + BodyData( + key="name", + type="text", + value="test", + ), + ], + ), + ) + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add( + ["1111", "file"], + FileVariable( + name="file", + value=File( + tenant_id="1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1111", + ), + ), + ) + node = HttpRequestNode( + id="1", + config={ + "id": "1", + "data": data.model_dump(), + }, + graph_init_params=GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ), + graph=Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, + ), + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, + ), + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ), + ) + monkeypatch.setattr( + "core.workflow.nodes.http_request.executor.file_manager.download", + lambda *args, **kwargs: b"test", + ) + + def attr_checker(*args, **kwargs): + assert kwargs["data"] == {"name": "test"} + assert kwargs["files"] == {"file": b"test"} + return httpx.Response(200, content=b"") + + monkeypatch.setattr( + "core.helper.ssrf_proxy.post", + attr_checker, + ) + result = node._run() + assert 
result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["body"] == "" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 0795f134d0..8f38d3f280 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -1,16 +1,20 @@ import time import uuid -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import UserFrom +from core.file import File, FileTransferMethod, FileType +from core.variables import ArrayFileSegment from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.if_else.entities import IfElseNodeData from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition from extensions.ext_database import db +from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType @@ -111,6 +115,7 @@ def test_execute_if_else_result_true(): result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs["result"] is True @@ -191,4 +196,63 @@ def test_execute_if_else_result_false(): result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs["result"] is False + + +def test_array_file_contains_file_name(): + node_data = IfElseNodeData( + 
title="123", + logical_operator="and", + cases=[ + IfElseNodeData.Case( + case_id="true", + logical_operator="and", + conditions=[ + Condition( + comparison_operator="contains", + variable_selector=["start", "array_contains"], + sub_variable_condition=SubVariableCondition( + logical_operator="and", + conditions=[ + SubCondition( + key="name", + comparison_operator="contains", + value="ab", + ) + ], + ), + ) + ], + ) + ], + ) + + node = IfElseNode( + id=str(uuid.uuid4()), + graph_init_params=Mock(), + graph=Mock(), + graph_runtime_state=Mock(), + config={ + "id": "if-else", + "data": node_data.model_dump(), + }, + ) + + node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( + value=[ + File( + tenant_id="1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1", + filename="ab", + ), + ], + ) + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["result"] is True diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py new file mode 100644 index 0000000000..53e3c93fcc --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -0,0 +1,111 @@ +from unittest.mock import MagicMock + +import pytest + +from core.file import File +from core.file.models import FileTransferMethod, FileType +from core.variables import ArrayFileSegment +from core.workflow.nodes.list_operator.entities import FilterBy, FilterCondition, Limit, ListOperatorNodeData, OrderBy +from core.workflow.nodes.list_operator.node import ListOperatorNode +from models.workflow import WorkflowNodeExecutionStatus + + +@pytest.fixture +def list_operator_node(): + config = { + "variable": ["test_variable"], + "filter_by": FilterBy( + enabled=True, + conditions=[ + FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, 
FileType.DOCUMENT]) + ], + ), + "order_by": OrderBy(enabled=False, value="asc"), + "limit": Limit(enabled=False, size=0), + "title": "Test Title", + } + node_data = ListOperatorNodeData(**config) + node = ListOperatorNode( + id="test_node_id", + config={ + "id": "test_node_id", + "data": node_data.model_dump(), + }, + graph_init_params=MagicMock(), + graph=MagicMock(), + graph_runtime_state=MagicMock(), + ) + node.graph_runtime_state = MagicMock() + node.graph_runtime_state.variable_pool = MagicMock() + return node + + +def test_filter_files_by_type(list_operator_node): + # Setup test data + files = [ + File( + filename="image1.jpg", + type=FileType.IMAGE, + tenant_id="tenant1", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="related1", + ), + File( + filename="document1.pdf", + type=FileType.DOCUMENT, + tenant_id="tenant1", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="related2", + ), + File( + filename="image2.png", + type=FileType.IMAGE, + tenant_id="tenant1", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="related3", + ), + File( + filename="audio1.mp3", + type=FileType.AUDIO, + tenant_id="tenant1", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="related4", + ), + ] + variable = ArrayFileSegment(value=files) + list_operator_node.graph_runtime_state.variable_pool.get.return_value = variable + + # Run the node + result = list_operator_node._run() + + # Verify the result + expected_files = [ + { + "filename": "image1.jpg", + "type": FileType.IMAGE, + "tenant_id": "tenant1", + "transfer_method": FileTransferMethod.LOCAL_FILE, + "related_id": "related1", + }, + { + "filename": "document1.pdf", + "type": FileType.DOCUMENT, + "tenant_id": "tenant1", + "transfer_method": FileTransferMethod.LOCAL_FILE, + "related_id": "related2", + }, + { + "filename": "image2.png", + "type": FileType.IMAGE, + "tenant_id": "tenant1", + "transfer_method": FileTransferMethod.LOCAL_FILE, + "related_id": "related3", + }, + ] + 
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + for expected_file, result_file in zip(expected_files, result.outputs["result"]): + assert expected_file["filename"] == result_file.filename + assert expected_file["type"] == result_file.type + assert expected_file["tenant_id"] == result_file.tenant_id + assert expected_file["transfer_method"] == result_file.transfer_method + assert expected_file["related_id"] == result_file.related_id diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py new file mode 100644 index 0000000000..f990280c5f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py @@ -0,0 +1,67 @@ +from core.model_runtime.entities import ImagePromptMessageContent +from core.workflow.nodes.question_classifier import QuestionClassifierNodeData + + +def test_init_question_classifier_node_data(): + data = { + "title": "test classifier node", + "query_variable_selector": ["id", "name"], + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, + "classes": [{"id": "1", "name": "class 1"}], + "instruction": "This is a test instruction", + "memory": { + "role_prefix": {"user": "Human:", "assistant": "AI:"}, + "window": {"enabled": True, "size": 5}, + "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:", + }, + "vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}}, + } + + node_data = QuestionClassifierNodeData(**data) + + assert node_data.query_variable_selector == ["id", "name"] + assert node_data.model.provider == "openai" + assert node_data.classes[0].id == "1" + assert node_data.instruction == "This is a test instruction" + assert node_data.memory is not None + assert node_data.memory.role_prefix is not None + assert node_data.memory.role_prefix.user == "Human:" + assert 
node_data.memory.role_prefix.assistant == "AI:" + assert node_data.memory.window.enabled == True + assert node_data.memory.window.size == 5 + assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:" + assert node_data.vision.enabled == True + assert node_data.vision.configs.variable_selector == ["image"] + assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.LOW + + +def test_init_question_classifier_node_data_without_vision_config(): + data = { + "title": "test classifier node", + "query_variable_selector": ["id", "name"], + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, + "classes": [{"id": "1", "name": "class 1"}], + "instruction": "This is a test instruction", + "memory": { + "role_prefix": {"user": "Human:", "assistant": "AI:"}, + "window": {"enabled": True, "size": 5}, + "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:", + }, + } + + node_data = QuestionClassifierNodeData(**data) + + assert node_data.query_variable_selector == ["id", "name"] + assert node_data.model.provider == "openai" + assert node_data.classes[0].id == "1" + assert node_data.instruction == "This is a test instruction" + assert node_data.memory is not None + assert node_data.memory.role_prefix is not None + assert node_data.memory.role_prefix.user == "Human:" + assert node_data.memory.role_prefix.assistant == "AI:" + assert node_data.memory.window.enabled == True + assert node_data.memory.window.size == 5 + assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:" + assert node_data.vision.enabled == False + assert node_data.vision.configs.variable_selector == ["sys", "files"] + assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py 
b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py index f45a93f1be..096ae0ea52 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -4,14 +4,14 @@ from unittest import mock from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.segments import ArrayStringVariable, StringVariable -from core.workflow.entities.node_entities import UserFrom +from core.variables import ArrayStringVariable, StringVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode +from models.enums import UserFrom from models.workflow import WorkflowType DEFAULT_NODE_ID = "node_id" diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py new file mode 100644 index 0000000000..a1e4dda627 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -0,0 +1,45 @@ +import pytest + +from core.file import File, FileTransferMethod, FileType +from core.variables import FileSegment, StringSegment +from core.workflow.entities.variable_pool import VariablePool + + +@pytest.fixture +def pool(): + return VariablePool(system_variables={}, user_inputs={}) + + +@pytest.fixture +def file(): + return File( + tenant_id="test_tenant_id", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test_related_id", + remote_url="test_url", + filename="test_file.txt", + ) + + +def test_get_file_attribute(pool, file): + # Add a FileSegment to the pool + pool.add(("node_1", 
"file_var"), FileSegment(value=file)) + + # Test getting the 'name' attribute of the file + result = pool.get(("node_1", "file_var", "name")) + + assert result is not None + assert result.value == file.filename + + # Test getting a non-existent attribute + with pytest.raises(ValueError): + pool.get(("node_1", "file_var", "non_existent_attr")) + + +def test_use_long_selector(pool): + pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value")) + + result = pool.get(("node_1", "part_1", "part_2")) + assert result is not None + assert result.value == "test_value" diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py new file mode 100644 index 0000000000..2f90afcf89 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py @@ -0,0 +1,28 @@ +from core.variables import SecretVariable +from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.utils import variable_template_parser + + +def test_extract_selectors_from_template(): + variable_pool = VariablePool( + system_variables={ + SystemVariableKey("user_id"): "fake-user-id", + }, + user_inputs={}, + environment_variables=[ + SecretVariable(name="secret_key", value="fake-secret-key"), + ], + conversation_variables=[], + ) + variable_pool.add(("node_id", "custom_query"), "fake-user-query") + template = ( + "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." 
+ ) + selectors = variable_template_parser.extract_selectors_from_template(template) + assert selectors == [ + VariableSelector(variable="#sys.user_id#", value_selector=["sys", "user_id"]), + VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]), + VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]), + ] diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py index 7968347dec..b879afa3e7 100644 --- a/api/tests/unit_tests/models/test_conversation_variable.py +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -1,11 +1,12 @@ from uuid import uuid4 -from core.app.segments import SegmentType, factory +from core.variables import SegmentType +from factories import variable_factory from models import ConversationVariable def test_from_variable_and_to_variable(): - variable = factory.build_variable_from_mapping( + variable = variable_factory.build_variable_from_mapping( { "id": str(uuid4()), "name": "name", diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index 40483d7e3a..478fa8012b 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -3,7 +3,7 @@ from uuid import uuid4 import contexts from constants import HIDDEN_VALUE -from core.app.segments import FloatVariable, IntegerVariable, SecretVariable, StringVariable +from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable from models.workflow import Workflow