merge main

This commit is contained in:
Joel 2024-10-28 10:51:02 +08:00
commit 765eb282f3
858 changed files with 16206 additions and 17932 deletions

2
.gitignore vendored
View File

@ -175,6 +175,8 @@ docker/volumes/pgvector/data/*
docker/volumes/pgvecto_rs/data/*
docker/nginx/conf.d/default.conf
docker/nginx/ssl/*
!docker/nginx/ssl/.gitkeep
docker/middleware.env
sdks/python-client/build

View File

@ -7,7 +7,8 @@ Dify is licensed under the Apache License 2.0, with the following additional con
a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.
b. LOGO and copyright information: In the process of using Dify's frontend, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend.
- Frontend Definition: For the purposes of this license, the "frontend" of Dify includes all components located in the `web/` directory when running Dify from the raw source code, or the "web" image when running Dify with Docker.
Please contact business@dify.ai by email to inquire about licensing matters.

View File

@ -1,5 +1,9 @@
![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab)
<p align="center">
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast</a>
</p>
<p align="center">
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
<a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·
@ -168,7 +172,7 @@ Star Dify on GitHub and be instantly notified of new releases.
> Before installing Dify, make sure your machine meets the following minimum system requirements:
>
>- CPU >= 2 Core
>- RAM >= 4GB
>- RAM >= 4 GiB
</br>

View File

@ -154,7 +154,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI
我们提供[ Dify 云服务](https://dify.ai),任何人都可以零设置尝试。它提供了自部署版本的所有功能,并在沙盒计划中包含 200 次免费的 GPT-4 调用。
- **自托管 Dify 社区版</br>**
使用这个[入门指南](#quick-start)快速在您的环境中运行 Dify。
使用这个[入门指南](#快速启动)快速在您的环境中运行 Dify。
使用我们的[文档](https://docs.dify.ai)进行进一步的参考和更深入的说明。
- **面向企业/组织的 Dify</br>**
@ -174,7 +174,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI
在安装 Dify 之前,请确保您的机器满足以下最低系统要求:
- CPU >= 2 Core
- RAM >= 4GB
- RAM >= 4 GiB
### 快速启动

View File

@ -31,8 +31,17 @@ REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_USERNAME=
REDIS_PASSWORD=difyai123456
REDIS_USE_SSL=false
REDIS_DB=0
# redis Sentinel configuration.
REDIS_USE_SENTINEL=false
REDIS_SENTINELS=
REDIS_SENTINEL_SERVICE_NAME=
REDIS_SENTINEL_USERNAME=
REDIS_SENTINEL_PASSWORD=
REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
# PostgreSQL database configuration
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
@ -42,7 +51,7 @@ DB_DATABASE=dify
# Storage configuration
# use for store upload files, private keys...
# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos, baidu-obs, supabase
# storage type: local, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
STORAGE_TYPE=local
STORAGE_LOCAL_PATH=storage
S3_USE_AWS_MANAGED_IAM=false
@ -111,7 +120,7 @@ SUPABASE_URL=your-server-url
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb, upstash
VECTOR_STORE=weaviate
# Weaviate configuration
@ -220,6 +229,10 @@ BAIDU_VECTOR_DB_DATABASE=dify
BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3
# Upstash configuration
UPSTASH_VECTOR_URL=your-server-url
UPSTASH_VECTOR_TOKEN=your-access-token
# ViKingDB configuration
VIKINGDB_ACCESS_KEY=your-ak
VIKINGDB_SECRET_KEY=your-sk
@ -233,10 +246,13 @@ VIKINGDB_SOCKET_TIMEOUT=30
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Model Configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
# Mail configuration, support: resend, smtp
MAIL_TYPE=
@ -302,6 +318,10 @@ RESPECT_XFORWARD_HEADERS_ENABLED=false
# Log file path
LOG_FILE=
# Log file max size, the unit is MB
LOG_FILE_MAX_SIZE=20
# Log file max backup count
LOG_FILE_BACKUP_COUNT=5
# Indexing configuration
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
@ -310,6 +330,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
MAX_VARIABLE_SIZE=204800
# App configuration
APP_MAX_EXECUTION_TIME=1200
@ -327,3 +348,6 @@ POSITION_TOOL_EXCLUDES=
POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES=
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5

View File

@ -1,8 +1,15 @@
{
"version": "0.2.0",
"compounds": [
{
"name": "Launch Flask and Celery",
"configurations": ["Python: Flask", "Python: Celery"]
}
],
"configurations": [
{
"name": "Python: Flask",
"consoleName": "Flask",
"type": "debugpy",
"request": "launch",
"python": "${workspaceFolder}/.venv/bin/python",
@ -17,12 +24,12 @@
},
"args": [
"run",
"--host=0.0.0.0",
"--port=5001"
]
},
{
"name": "Python: Celery",
"consoleName": "Celery",
"type": "debugpy",
"request": "launch",
"python": "${workspaceFolder}/.venv/bin/python",
@ -45,10 +52,10 @@
"-c",
"1",
"--loglevel",
"info",
"DEBUG",
"-Q",
"dataset,generation,mail,ops_trace,app_deletion"
]
},
}
]
}

View File

@ -55,7 +55,9 @@ RUN apt-get update \
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
&& apt-get update \
# For Security
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \
# install a chinese font to support the use of tools like matplotlib
&& apt-get install -y fonts-noto-cjk \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*

View File

@ -1,5 +1,7 @@
import os
from configs import dify_config
if os.environ.get("DEBUG", "false").lower() != "true":
from gevent import monkey
@ -20,6 +22,7 @@ from app_factory import create_app
# DO NOT REMOVE BELOW
from events import event_handlers # noqa: F401
from extensions.ext_database import db
# TODO: Find a way to avoid importing models here
from models import account, dataset, model, source, task, tool, tools, web # noqa: F401
@ -35,17 +38,11 @@ if hasattr(time, "tzset"):
time.tzset()
# -------------
# Configuration
# -------------
config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first
# create app
app = create_app()
celery = app.extensions["celery"]
if app.config.get("TESTING"):
if dify_config.TESTING:
print("App is running in TESTING mode")
@ -53,15 +50,15 @@ if app.config.get("TESTING"):
def after_request(response):
"""Add Version headers to the response."""
response.set_cookie("remember_token", "", expires=0)
response.headers.add("X-Version", app.config["CURRENT_VERSION"])
response.headers.add("X-Env", app.config["DEPLOY_ENV"])
response.headers.add("X-Version", dify_config.CURRENT_VERSION)
response.headers.add("X-Env", dify_config.DEPLOY_ENV)
return response
@app.route("/health")
def health():
return Response(
json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}),
status=200,
content_type="application/json",
)

View File

@ -10,9 +10,6 @@ if os.environ.get("DEBUG", "false").lower() != "true":
grpc.experimental.gevent.init_gevent()
import json
import logging
import sys
from logging.handlers import RotatingFileHandler
from flask import Flask, Response, request
from flask_cors import CORS
@ -27,6 +24,7 @@ from extensions import (
ext_compress,
ext_database,
ext_hosting_provider,
ext_logging,
ext_login,
ext_mail,
ext_migrate,
@ -70,43 +68,7 @@ def create_flask_app_with_configs() -> Flask:
def create_app() -> Flask:
app = create_flask_app_with_configs()
app.secret_key = app.config["SECRET_KEY"]
log_handlers = None
log_file = app.config.get("LOG_FILE")
if log_file:
log_dir = os.path.dirname(log_file)
os.makedirs(log_dir, exist_ok=True)
log_handlers = [
RotatingFileHandler(
filename=log_file,
maxBytes=1024 * 1024 * 1024,
backupCount=5,
),
logging.StreamHandler(sys.stdout),
]
logging.basicConfig(
level=app.config.get("LOG_LEVEL"),
format=app.config.get("LOG_FORMAT"),
datefmt=app.config.get("LOG_DATEFORMAT"),
handlers=log_handlers,
force=True,
)
log_tz = app.config.get("LOG_TZ")
if log_tz:
from datetime import datetime
import pytz
timezone = pytz.timezone(log_tz)
def time_converter(seconds):
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
for handler in logging.root.handlers:
handler.formatter.converter = time_converter
app.secret_key = dify_config.SECRET_KEY
initialize_extensions(app)
register_blueprints(app)
register_commands(app)
@ -117,6 +79,7 @@ def create_app() -> Flask:
def initialize_extensions(app):
# Since the application instance is now created, pass it to each Flask
# extension instance to bind it to the Flask application instance (app)
ext_logging.init_app(app)
ext_compress.init_app(app)
ext_code_based_extension.init()
ext_database.init_app(app)
@ -187,7 +150,7 @@ def register_blueprints(app):
CORS(
web_bp,
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
@ -198,7 +161,7 @@ def register_blueprints(app):
CORS(
console_app_bp,
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],

View File

@ -19,7 +19,7 @@ from extensions.ext_redis import redis_client
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair
from models.account import Tenant
from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
@ -277,6 +277,7 @@ def migrate_knowledge_vector_database():
VectorType.TENCENT,
VectorType.BAIDU,
VectorType.VIKINGDB,
VectorType.UPSTASH,
}
page = 1
while True:

View File

@ -1,6 +1,15 @@
from typing import Annotated, Optional
from typing import Annotated, Literal, Optional
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
from pydantic import (
AliasChoices,
Field,
HttpUrl,
NegativeInt,
NonNegativeInt,
PositiveFloat,
PositiveInt,
computed_field,
)
from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig
@ -11,16 +20,31 @@ class SecurityConfig(BaseSettings):
Security-related configurations for the application
"""
SECRET_KEY: Optional[str] = Field(
SECRET_KEY: str = Field(
description="Secret key for secure session cookie signing."
"Make sure you are changing this key for your deployment with a strong key."
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
default=None,
default="",
)
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
description="Duration in hours for which a password reset token remains valid",
default=24,
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
description="Duration in minutes for which a password reset token remains valid",
default=5,
)
LOGIN_DISABLED: bool = Field(
description="Whether to disable login checks",
default=False,
)
ADMIN_API_KEY_ENABLE: bool = Field(
description="Whether to enable admin api key for authentication",
default=False,
)
ADMIN_API_KEY: Optional[str] = Field(
description="admin api key for authentication",
default=None,
)
@ -177,6 +201,16 @@ class FileUploadConfig(BaseSettings):
default=10,
)
UPLOAD_VIDEO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="video file size limit in Megabytes for uploading files",
default=100,
)
UPLOAD_AUDIO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="audio file size limit in Megabytes for uploading files",
default=50,
)
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
description="Maximum number of files allowed in a batch upload operation",
default=20,
@ -285,6 +319,16 @@ class LoggingConfig(BaseSettings):
default=None,
)
LOG_FILE_MAX_SIZE: PositiveInt = Field(
description="Maximum file size for file rotation retention, the unit is megabytes (MB)",
default=20,
)
LOG_FILE_BACKUP_COUNT: PositiveInt = Field(
description="Maximum file backup count file rotation retention",
default=5,
)
LOG_FORMAT: str = Field(
description="Format string for log messages",
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
@ -355,8 +399,8 @@ class WorkflowConfig(BaseSettings):
)
MAX_VARIABLE_SIZE: PositiveInt = Field(
description="Maximum size in bytes for a single variable in workflows. Default to 5KB.",
default=5 * 1024,
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
default=200 * 1024,
)
@ -473,12 +517,18 @@ class MailConfig(BaseSettings):
default=False,
)
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
default=50,
)
class RagEtlConfig(BaseSettings):
"""
Configuration for RAG ETL processes
"""
# TODO: This config is not only for rag etl, it is also for file upload, we should move it to file upload config
ETL_TYPE: str = Field(
description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'",
default="dify",
@ -521,6 +571,11 @@ class DataSetConfig(BaseSettings):
default=False,
)
TIDB_SERVERLESS_NUMBER: PositiveInt = Field(
description="number of tidb serverless cluster",
default=500,
)
class WorkspaceConfig(BaseSettings):
"""
@ -545,7 +600,7 @@ class IndexingConfig(BaseSettings):
class ImageFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)
@ -614,6 +669,33 @@ class PositionConfig(BaseSettings):
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
class LoginConfig(BaseSettings):
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
description="whether to enable email code login",
default=False,
)
ENABLE_EMAIL_PASSWORD_LOGIN: bool = Field(
description="whether to enable email password login",
default=True,
)
ENABLE_SOCIAL_OAUTH_LOGIN: bool = Field(
description="whether to enable github/google oauth login",
default=False,
)
EMAIL_CODE_LOGIN_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
description="expiry time in minutes for email code login token",
default=5,
)
ALLOW_REGISTER: bool = Field(
description="whether to enable register",
default=False,
)
ALLOW_CREATE_WORKSPACE: bool = Field(
description="whether to enable create workspace",
default=False,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -639,6 +721,7 @@ class FeatureConfig(
UpdateConfig,
WorkflowConfig,
WorkspaceConfig,
LoginConfig,
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,

View File

@ -27,7 +27,9 @@ from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig
from configs.middleware.vdb.qdrant_config import QdrantConfig
from configs.middleware.vdb.relyt_config import RelytConfig
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
from configs.middleware.vdb.upstash_config import UpstashConfig
from configs.middleware.vdb.vikingdb_config import VikingDBConfig
from configs.middleware.vdb.weaviate_config import WeaviateConfig
@ -35,7 +37,8 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig
class StorageConfig(BaseSettings):
STORAGE_TYPE: str = Field(
description="Type of storage to use."
" Options: 'local', 's3', 'azure-blob', 'aliyun-oss', 'google-storage'. Default is 'local'.",
" Options: 'local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', 'huawei-obs', "
"'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'local'.",
default="local",
)
@ -52,6 +55,11 @@ class VectorStoreConfig(BaseSettings):
default=None,
)
VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field(
description="Enable whitelist for vector store.",
default=False,
)
class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field(
@ -245,5 +253,7 @@ class MiddlewareConfig(
ElasticsearchConfig,
InternalTestConfig,
VikingDBConfig,
UpstashConfig,
TidbOnQdrantConfig,
):
pass

View File

@ -0,0 +1,65 @@
from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class TidbOnQdrantConfig(BaseSettings):
"""
Tidb on Qdrant configs
"""
TIDB_ON_QDRANT_URL: Optional[str] = Field(
description="Tidb on Qdrant url",
default=None,
)
TIDB_ON_QDRANT_API_KEY: Optional[str] = Field(
description="Tidb on Qdrant api key",
default=None,
)
TIDB_ON_QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
description="Tidb on Qdrant client timeout in seconds",
default=20,
)
TIDB_ON_QDRANT_GRPC_ENABLED: bool = Field(
description="whether enable grpc support for Tidb on Qdrant connection",
default=False,
)
TIDB_ON_QDRANT_GRPC_PORT: PositiveInt = Field(
description="Tidb on Qdrant grpc port",
default=6334,
)
TIDB_PUBLIC_KEY: Optional[str] = Field(
description="Tidb account public key",
default=None,
)
TIDB_PRIVATE_KEY: Optional[str] = Field(
description="Tidb account private key",
default=None,
)
TIDB_API_URL: Optional[str] = Field(
description="Tidb API url",
default=None,
)
TIDB_IAM_API_URL: Optional[str] = Field(
description="Tidb IAM API url",
default=None,
)
TIDB_REGION: Optional[str] = Field(
description="Tidb serverless region",
default="regions/aws-us-east-1",
)
TIDB_PROJECT_ID: Optional[str] = Field(
description="Tidb project id",
default=None,
)

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.9.2",
default="0.10.1",
)
COMMIT_SHA: str = Field(

View File

@ -1,2 +1,24 @@
from configs import dify_config
HIDDEN_VALUE = "[__HIDDEN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000"
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
if dify_config.ETL_TYPE == "Unstructured":
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
if dify_config.UNSTRUCTURED_API_URL:
DOCUMENT_EXTENSIONS.append("ppt")
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
else:
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])

View File

@ -1,7 +1,9 @@
from contextvars import ContextVar
from typing import TYPE_CHECKING
from core.workflow.entities.variable_pool import VariablePool
if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")

View File

@ -1,10 +1,10 @@
import os
from functools import wraps
from flask import request
from flask_restful import Resource, reqparse
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
from controllers.console.wraps import only_edition_cloud
@ -15,7 +15,7 @@ from models.model import App, InstalledApp, RecommendedApp
def admin_required(view):
@wraps(view)
def decorated(*args, **kwargs):
if not os.getenv("ADMIN_API_KEY"):
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")
auth_header = request.headers.get("Authorization")
@ -31,7 +31,7 @@ def admin_required(view):
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if os.getenv("ADMIN_API_KEY") != auth_token:
if dify_config.ADMIN_API_KEY != auth_token:
raise Unauthorized("API key is invalid.")
return view(*args, **kwargs)

View File

@ -22,7 +22,8 @@ from fields.conversation_fields import (
)
from libs.helper import DatetimeString
from libs.login import login_required
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
class CompletionConversationApi(Resource):

View File

@ -52,4 +52,39 @@ class RuleGenerateApi(Resource):
return rules
class RuleCodeGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
args = parser.parse_args()
account = current_user
CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024"))
try:
code_result = LLMGenerator.generate_code(
tenant_id=account.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["code_language"],
max_tokens=CODE_GENERATION_MAX_TOKENS,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
return code_result
api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")

View File

@ -105,6 +105,8 @@ class ChatMessageListApi(Resource):
if rest_count > 0:
has_more = True
history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)

View File

@ -12,7 +12,7 @@ from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from fields.app_fields import app_site_fields
from libs.login import login_required
from models.model import Site
from models import Site
def parse_app_site_args():

View File

@ -13,14 +13,14 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.segments import factory
from core.errors.error import AppInvokeQuotaExceededError
from factories import variable_factory
from fields.workflow_fields import workflow_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models.model import App, AppMode
from models import App
from models.model import AppMode
from services.app_dsl_service import AppDslService
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
@ -101,9 +101,13 @@ class DraftWorkflowApi(Resource):
try:
environment_variables_list = args.get("environment_variables") or []
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
environment_variables = [
variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = args.get("conversation_variables") or []
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
conversation_variables = [
variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
]
workflow = workflow_service.sync_draft_workflow(
app_model=app_model,
graph=args["graph"],
@ -273,17 +277,15 @@ class DraftWorkflowRunApi(Resource):
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
)
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
)
return helper.compact_generate_response(response)
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
return helper.compact_generate_response(response)
class WorkflowTaskStopApi(Resource):

View File

@ -7,7 +7,8 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required
from models.model import App, AppMode
from models import App
from models.model import AppMode
from services.workflow_app_service import WorkflowAppService

View File

@ -13,7 +13,8 @@ from fields.workflow_run_fields import (
)
from libs.helper import uuid_value
from libs.login import login_required
from models.model import App, AppMode
from models import App
from models.model import AppMode
from services.workflow_run_service import WorkflowRunService

View File

@ -13,8 +13,8 @@ from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from models.workflow import WorkflowRunTriggeredFrom
class WorkflowDailyRunsStatistic(Resource):

View File

@ -5,7 +5,8 @@ from typing import Optional, Union
from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.model import App, AppMode
from models import App
from models.model import AppMode
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):

View File

@ -1,17 +1,15 @@
import base64
import datetime
import secrets
from flask import request
from flask_restful import Resource, reqparse
from constants.languages import supported_language
from controllers.console import api
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.helper import StrLen, email, timezone
from libs.password import hash_password, valid_password
from models.account import AccountStatus
from services.account_service import RegisterService
from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus, Tenant
from services.account_service import AccountService, RegisterService
class ActivateCheckApi(Resource):
@ -27,8 +25,18 @@ class ActivateCheckApi(Resource):
token = args["token"]
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None}
if invitation:
data = invitation.get("data", {})
tenant: Tenant = invitation.get("tenant", None)
workspace_name = tenant.name if tenant else None
workspace_id = tenant.id if tenant else None
invitee_email = data.get("email") if data else None
return {
"is_valid": invitation is not None,
"data": {"workspace_name": workspace_name, "workspace_id": workspace_id, "email": invitee_email},
}
else:
return {"is_valid": False}
class ActivateApi(Resource):
@ -38,7 +46,6 @@ class ActivateApi(Resource):
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json"
)
@ -54,15 +61,6 @@ class ActivateApi(Resource):
account = invitation["account"]
account.name = args["name"]
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt
password_hashed = hash_password(args["password"], salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_theme = "light"
@ -70,7 +68,9 @@ class ActivateApi(Resource):
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
return {"result": "success"}
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
return {"result": "success", "data": token_pair.model_dump()}
api.add_resource(ActivateCheckApi, "/activate/check")

View File

@ -27,5 +27,29 @@ class InvalidTokenError(BaseHTTPException):
class PasswordResetRateLimitExceededError(BaseHTTPException):
error_code = "password_reset_rate_limit_exceeded"
description = "Password reset rate limit exceeded. Try again later."
description = "Too many password reset emails have been sent. Please try again in 1 minutes."
code = 429
class EmailCodeError(BaseHTTPException):
error_code = "email_code_error"
description = "Email code is invalid or expired."
code = 400
class EmailOrPasswordMismatchError(BaseHTTPException):
error_code = "email_or_password_mismatch"
description = "The email or password is mismatched."
code = 400
class EmailPasswordLoginLimitError(BaseHTTPException):
error_code = "email_code_login_limit"
description = "Too many incorrect password attempts. Please try again later."
code = 429
class EmailCodeLoginRateLimitExceededError(BaseHTTPException):
error_code = "email_code_login_rate_limit_exceeded"
description = "Too many login emails have been sent. Please try again in 5 minutes."
code = 429

View File

@ -1,65 +1,82 @@
import base64
import logging
import secrets
from flask import request
from flask_restful import Resource, reqparse
from constants.languages import languages
from controllers.console import api
from controllers.console.auth.error import (
EmailCodeError,
InvalidEmailError,
InvalidTokenError,
PasswordMismatchError,
PasswordResetRateLimitExceededError,
)
from controllers.console.error import EmailSendIpLimitError, NotAllowedRegister
from controllers.console.setup import setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import email as email_validate
from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from services.account_service import AccountService
from services.errors.account import RateLimitExceededError
from services.account_service import AccountService, TenantService
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.feature_service import FeatureService
class ForgotPasswordSendEmailApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
email = args["email"]
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if not email_validate(email):
raise InvalidEmailError()
account = Account.query.filter_by(email=email).first()
if account:
try:
AccountService.send_reset_password_email(account=account)
except RateLimitExceededError:
logging.warning(f"Rate limit exceeded for email: {account.email}")
raise PasswordResetRateLimitExceededError()
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
# Return success to avoid revealing email registration status
logging.warning(f"Attempt to reset password for unregistered email: {email}")
language = "en-US"
return {"result": "success"}
account = Account.query.filter_by(email=args["email"]).first()
token = None
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
return {"result": "fail", "data": token, "code": "account_not_found"}
else:
raise NotAllowedRegister()
else:
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
return {"result": "success", "data": token}
class ForgotPasswordCheckApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
token = args["token"]
reset_data = AccountService.get_reset_password_data(token)
user_email = args["email"]
if reset_data is None:
return {"is_valid": False, "email": None}
return {"is_valid": True, "email": reset_data.get("email")}
token_data = AccountService.get_reset_password_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
raise EmailCodeError()
return {"is_valid": True, "email": token_data.get("email")}
class ForgotPasswordResetApi(Resource):
@ -92,9 +109,26 @@ class ForgotPasswordResetApi(Resource):
base64_password_hashed = base64.b64encode(password_hashed).decode()
account = Account.query.filter_by(email=reset_data.get("email")).first()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
if account:
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
tenant = TenantService.get_join_tenants(account)
if not tenant and not FeatureService.get_system_features().is_allow_create_workspace:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
else:
try:
account = AccountService.create_account_and_tenant(
email=reset_data.get("email"),
name=reset_data.get("email"),
password=password_confirm,
interface_language=languages[0],
)
except WorkSpaceNotAllowedCreateError:
pass
return {"result": "success"}

View File

@ -5,12 +5,29 @@ from flask import request
from flask_restful import Resource, reqparse
import services
from constants.languages import languages
from controllers.console import api
from controllers.console.auth.error import (
EmailCodeError,
EmailOrPasswordMismatchError,
EmailPasswordLoginLimitError,
InvalidEmailError,
InvalidTokenError,
)
from controllers.console.error import (
AccountBannedError,
EmailSendIpLimitError,
NotAllowedCreateWorkspace,
NotAllowedRegister,
)
from controllers.console.setup import setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from libs.password import valid_password
from models.account import Account
from services.account_service import AccountService, TenantService
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.feature_service import FeatureService
class LoginApi(Resource):
@ -23,15 +40,43 @@ class LoginApi(Resource):
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
parser.add_argument("language", type=str, required=False, default="en-US", location="json")
args = parser.parse_args()
# todo: Verify the recaptcha
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
invitation = args["invite_token"]
if invitation:
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
account = AccountService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError as e:
return {"code": "unauthorized", "message": str(e)}, 401
if invitation:
data = invitation.get("data", {})
invitee_email = data.get("email") if data else None
if invitee_email != args["email"]:
raise InvalidEmailError()
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
else:
account = AccountService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args["email"])
raise EmailOrPasswordMismatchError()
except services.errors.account.AccountNotFoundError:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
return {"result": "fail", "data": token, "code": "account_not_found"}
else:
raise NotAllowedRegister()
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
@ -41,7 +86,7 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
@ -49,60 +94,111 @@ class LogoutApi(Resource):
@setup_required
def get(self):
account = cast(Account, flask_login.current_user)
if isinstance(account, flask_login.AnonymousUserMixin):
return {"result": "success"}
AccountService.logout(account=account)
flask_login.logout_user()
return {"result": "success"}
class ResetPasswordApi(Resource):
class ResetPasswordSendEmailApi(Resource):
@setup_required
def get(self):
# parser = reqparse.RequestParser()
# parser.add_argument('email', type=email, required=True, location='json')
# args = parser.parse_args()
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
# import mailchimp_transactional as MailchimpTransactional
# from mailchimp_transactional.api_client import ApiClientError
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
# account = {'email': args['email']}
# account = AccountService.get_by_email(args['email'])
# if account is None:
# raise ValueError('Email not found')
# new_password = AccountService.generate_password()
# AccountService.update_password(account, new_password)
account = AccountService.get_user_through_email(args["email"])
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
else:
raise NotAllowedRegister()
else:
token = AccountService.send_reset_password_email(account=account, language=language)
# todo: Send email
# MAILCHIMP_API_KEY = dify_config.MAILCHIMP_TRANSACTIONAL_API_KEY
# mailchimp = MailchimpTransactional(MAILCHIMP_API_KEY)
return {"result": "success", "data": token}
# message = {
# 'from_email': 'noreply@example.com',
# 'to': [{'email': account['email']}],
# 'subject': 'Reset your Dify password',
# 'html': """
# <p>Dear User,</p>
# <p>The Dify team has generated a new password for you, details as follows:</p>
# <p><strong>{new_password}</strong></p>
# <p>Please change your password to log in as soon as possible.</p>
# <p>Regards,</p>
# <p>The Dify Team</p>
# """
# }
# response = mailchimp.messages.send({
# 'message': message,
# # required for transactional email
# ' settings': {
# 'sandbox_mode': dify_config.MAILCHIMP_SANDBOX_MODE,
# },
# })
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
# Check if MSG was sent
# if response.status_code != 200:
# # handle error
# pass
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
return {"result": "success"}
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = AccountService.get_user_through_email(args["email"])
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
else:
raise NotAllowedRegister()
else:
token = AccountService.send_email_code_login_email(account=account, language=language)
return {"result": "success", "data": token}
class EmailCodeLoginApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, location="json")
args = parser.parse_args()
user_email = args["email"]
token_data = AccountService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
raise EmailCodeError()
AccountService.revoke_email_code_login_token(args["token"])
account = AccountService.get_user_through_email(user_email)
if account:
tenant = TenantService.get_join_tenants(account)
if not tenant:
if not FeatureService.get_system_features().is_allow_create_workspace:
raise NotAllowedCreateWorkspace()
else:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
if account is None:
try:
account = AccountService.create_account_and_tenant(
email=user_email, name=user_email, interface_language=languages[0]
)
except WorkSpaceNotAllowedCreateError:
return NotAllowedCreateWorkspace()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
class RefreshTokenApi(Resource):
@ -120,4 +216,7 @@ class RefreshTokenApi(Resource):
api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
api.add_resource(ResetPasswordSendEmailApi, "/reset-password")
api.add_resource(RefreshTokenApi, "/refresh-token")

View File

@ -5,14 +5,20 @@ from typing import Optional
import requests
from flask import current_app, redirect, request
from flask_restful import Resource
from werkzeug.exceptions import Unauthorized
from configs import dify_config
from constants.languages import languages
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models.account import Account, AccountStatus
from models import Account
from models.account import AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountNotFoundError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
from services.feature_service import FeatureService
from .. import api
@ -42,6 +48,7 @@ def get_oauth_providers():
class OAuthLogin(Resource):
def get(self, provider: str):
invite_token = request.args.get("invite_token") or None
OAUTH_PROVIDERS = get_oauth_providers()
with current_app.app_context():
oauth_provider = OAUTH_PROVIDERS.get(provider)
@ -49,7 +56,7 @@ class OAuthLogin(Resource):
if not oauth_provider:
return {"error": "Invalid provider"}, 400
auth_url = oauth_provider.get_authorization_url()
auth_url = oauth_provider.get_authorization_url(invite_token=invite_token)
return redirect(auth_url)
@ -62,6 +69,11 @@ class OAuthCallback(Resource):
return {"error": "Invalid provider"}, 400
code = request.args.get("code")
state = request.args.get("state")
invite_token = None
if state:
invite_token = state
try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
@ -69,17 +81,43 @@ class OAuthCallback(Resource):
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
return {"error": "OAuth process failed"}, 400
account = _generate_account(provider, user_info)
if invite_token and RegisterService.is_valid_invite_token(invite_token):
invitation = RegisterService._get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
if invitation_email != user_info.email:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
try:
account = _generate_account(provider, user_info)
except AccountNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
# Check account status
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
return {"error": "Account is banned or closed."}, 403
if account.status == AccountStatus.BANNED.value:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
TenantService.create_owner_tenant_if_not_exist(account)
try:
TenantService.create_owner_tenant_if_not_exist(account)
except Unauthorized:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
except WorkSpaceNotAllowedCreateError:
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
token_pair = AccountService.login(
account=account,
@ -104,8 +142,20 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
# Get account by openid or email.
account = _get_account_by_openid_or_email(provider, user_info)
if account:
tenant = TenantService.get_join_tenants(account)
if not tenant:
if not FeatureService.get_system_features().is_allow_create_workspace:
raise WorkSpaceNotAllowedCreateError()
else:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
if not account:
# Create account
if not FeatureService.get_system_features().is_allow_register:
raise AccountNotFoundError()
account_name = user_info.name or "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider

View File

@ -15,8 +15,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.login import login_required
from models.dataset import Document
from models.source import DataSourceOauthBinding
from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
from tasks.document_indexing_sync_task import document_indexing_sync_task

View File

@ -24,8 +24,8 @@ from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from libs.login import login_required
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from models.model import ApiToken, UploadFile
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@ -102,6 +102,13 @@ class DatasetListApi(Resource):
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"indexing_technique",
type=str,
@ -140,6 +147,7 @@ class DatasetListApi(Resource):
dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
account=current_user,
permission=DatasetPermissionEnum.ONLY_ME,
@ -619,6 +627,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
@ -630,6 +639,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.PGVECTOR
| VectorType.TIDB_ON_QDRANT
):
return {
"retrieval_method": [
@ -657,6 +667,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (

View File

@ -46,8 +46,7 @@ from fields.document_fields import (
document_with_segments_fields,
)
from libs.login import login_required
from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
from models.model import UploadFile
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from services.dataset_service import DatasetService, DocumentService
from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task

View File

@ -24,7 +24,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import segment_fields
from libs.login import login_required
from models.dataset import DocumentSegment
from models import DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task

View File

@ -1,9 +1,12 @@
import urllib.parse
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal_with
from flask_restful import Resource, marshal_with, reqparse
import services
from configs import dify_config
from constants import DOCUMENT_EXTENSIONS
from controllers.console import api
from controllers.console.datasets.error import (
FileTooLargeError,
@ -13,9 +16,10 @@ from controllers.console.datasets.error import (
)
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from fields.file_fields import file_fields, upload_config_fields
from core.helper import ssrf_proxy
from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields
from libs.login import login_required
from services.file_service import ALLOWED_EXTENSIONS, UNSTRUCTURED_ALLOWED_EXTENSIONS, FileService
from services.file_service import FileService
PREVIEW_WORDS_LIMIT = 3000
@ -26,13 +30,12 @@ class FileApi(Resource):
@account_initialization_required
@marshal_with(upload_config_fields)
def get(self):
file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT
batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
return {
"file_size_limit": file_size_limit,
"batch_count_limit": batch_count_limit,
"image_file_size_limit": image_file_size_limit,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
}, 200
@setup_required
@ -44,6 +47,10 @@ class FileApi(Resource):
# get file from request
file = request.files["file"]
parser = reqparse.RequestParser()
parser.add_argument("source", type=str, required=False, location="args")
source = parser.parse_args().get("source")
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -51,7 +58,7 @@ class FileApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
try:
upload_file = FileService.upload_file(file, current_user)
upload_file = FileService.upload_file(file=file, user=current_user, source=source)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
@ -75,11 +82,24 @@ class FileSupportTypeApi(Resource):
@login_required
@account_initialization_required
def get(self):
etl_type = dify_config.ETL_TYPE
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
return {"allowed_extensions": allowed_extensions}
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
class RemoteFileInfoApi(Resource):
@marshal_with(remote_file_info_fields)
def get(self, url):
decoded_url = urllib.parse.unquote(url)
try:
response = ssrf_proxy.head(decoded_url)
return {
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
"file_length": int(response.headers.get("Content-Length", 0)),
}
except Exception as e:
return {"error": str(e)}, 400
api.add_resource(FileApi, "/files/upload")
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
api.add_resource(FileSupportTypeApi, "/files/support-type")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")

View File

@ -38,3 +38,27 @@ class AlreadyActivateError(BaseHTTPException):
error_code = "already_activate"
description = "Auth Token is invalid or account already activated, please check again."
code = 403
class NotAllowedCreateWorkspace(BaseHTTPException):
error_code = "not_allowed_create_workspace"
description = "Workspace not found, please contact system admin to invite you to join in a workspace."
code = 400
class AccountBannedError(BaseHTTPException):
error_code = "account_banned"
description = "Account is banned."
code = 400
class NotAllowedRegister(BaseHTTPException):
error_code = "unauthorized"
description = "Account not found."
code = 400
class EmailSendIpLimitError(BaseHTTPException):
error_code = "email_send_ip_limit"
description = "Too many emails have been sent from this IP address recently. Please try again later."
code = 429

View File

@ -11,7 +11,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.login import login_required
from models.model import App, InstalledApp, RecommendedApp
from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService

View File

@ -21,7 +21,12 @@ class AppParameterApi(InstalledAppResource):
"options": fields.List(fields.String),
}
system_parameters_fields = {"image_file_size_limit": fields.String}
system_parameters_fields = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"file_size_limit": fields.Integer,
}
parameters_fields = {
"opening_statement": fields.String,
@ -82,7 +87,12 @@ class AppParameterApi(InstalledAppResource):
}
},
),
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
},
}

View File

@ -18,7 +18,7 @@ message_fields = {
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"created_at": TimestampField,
}

View File

@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import login_required
from models.model import InstalledApp
from models import InstalledApp
def installed_app_required(view=None):

View File

@ -20,7 +20,7 @@ from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone
from libs.login import login_required
from models.account import AccountIntegrate, InvitationCode
from models import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError

View File

@ -360,16 +360,15 @@ class ToolWorkflowProviderCreateApi(Resource):
args = reqparser.parse_args()
return WorkflowToolManageService.create_workflow_tool(
user_id,
tenant_id,
args["workflow_app_id"],
args["name"],
args["label"],
args["icon"],
args["description"],
args["parameters"],
args["privacy_policy"],
args.get("labels", []),
user_id=user_id,
tenant_id=tenant_id,
workflow_app_id=args["workflow_app_id"],
name=args["name"],
label=args["label"],
icon=args["icon"],
description=args["description"],
parameters=args["parameters"],
privacy_policy=args["privacy_policy"],
)

View File

@ -198,7 +198,7 @@ class WebappLogoWorkspaceApi(Resource):
raise UnsupportedFileTypeError()
try:
upload_file = FileService.upload_file(file, current_user, True)
upload_file = FileService.upload_file(file=file, user=current_user)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)

View File

@ -1,5 +1,5 @@
from flask import Response, request
from flask_restful import Resource
from flask_restful import Resource, reqparse
from werkzeug.exceptions import NotFound
import services
@ -10,6 +10,10 @@ from services.file_service import FileService
class ImagePreviewApi(Resource):
"""
Deprecated
"""
def get(self, file_id):
file_id = str(file_id)
@ -21,13 +25,57 @@ class ImagePreviewApi(Resource):
return {"content": "Invalid request."}, 400
try:
generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign)
generator, mimetype = FileService.get_image_preview(
file_id=file_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype)
class FilePreviewApi(Resource):
def get(self, file_id):
file_id = str(file_id)
parser = reqparse.RequestParser()
parser.add_argument("timestamp", type=str, required=True, location="args")
parser.add_argument("nonce", type=str, required=True, location="args")
parser.add_argument("sign", type=str, required=True, location="args")
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
args = parser.parse_args()
if not args["timestamp"] or not args["nonce"] or not args["sign"]:
return {"content": "Invalid request."}, 400
try:
generator, upload_file = FileService.get_file_generator_by_file_id(
file_id=file_id,
timestamp=args["timestamp"],
nonce=args["nonce"],
sign=args["sign"],
)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
response = Response(
generator,
mimetype=upload_file.mime_type,
direct_passthrough=True,
headers={},
)
if upload_file.size > 0:
response.headers["Content-Length"] = str(upload_file.size)
if args["as_attachment"]:
response.headers["Content-Disposition"] = f"attachment; filename={upload_file.name}"
return response
class WorkspaceWebappLogoApi(Resource):
def get(self, workspace_id):
workspace_id = str(workspace_id)
@ -49,4 +97,5 @@ class WorkspaceWebappLogoApi(Resource):
api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview")
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/file-preview")
api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo")

View File

@ -16,6 +16,7 @@ class ToolFilePreviewApi(Resource):
parser.add_argument("timestamp", type=str, required=True, location="args")
parser.add_argument("nonce", type=str, required=True, location="args")
parser.add_argument("sign", type=str, required=True, location="args")
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
args = parser.parse_args()
@ -28,18 +29,27 @@ class ToolFilePreviewApi(Resource):
raise Forbidden("Invalid request.")
try:
result = ToolFileManager.get_file_generator_by_tool_file_id(
stream, tool_file = ToolFileManager.get_file_generator_by_tool_file_id(
file_id,
)
if not result:
if not stream or not tool_file:
raise NotFound("file is not found")
generator, mimetype = result
except Exception:
raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype)
response = Response(
stream,
mimetype=tool_file.mimetype,
direct_passthrough=True,
headers={},
)
if tool_file.size > 0:
response.headers["Content-Length"] = str(tool_file.size)
if args["as_attachment"]:
response.headers["Content-Disposition"] = f"attachment; filename={tool_file.name}"
return response
api.add_resource(ToolFilePreviewApi, "/files/tools/<uuid:file_id>.<string:extension>")

View File

@ -21,7 +21,12 @@ class AppParameterApi(Resource):
"options": fields.List(fields.String),
}
system_parameters_fields = {"image_file_size_limit": fields.String}
system_parameters_fields = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"file_size_limit": fields.Integer,
}
parameters_fields = {
"opening_statement": fields.String,
@ -81,7 +86,12 @@ class AppParameterApi(Resource):
}
},
),
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
},
}

View File

@ -48,7 +48,7 @@ class MessageListApi(Resource):
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"message_files": fields.List(fields.String, attribute="files"),
"message_files": fields.List(fields.Nested(message_file_fields)),
}
message_fields = {
@ -58,7 +58,7 @@ class MessageListApi(Resource):
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField,

View File

@ -66,6 +66,13 @@ class DatasetListApi(DatasetApiResource):
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"indexing_technique",
type=str,
@ -108,6 +115,7 @@ class DatasetListApi(DatasetApiResource):
dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
account=current_user,
permission=args["permission"],

View File

@ -21,7 +21,12 @@ class AppParameterApi(WebApiResource):
"options": fields.List(fields.String),
}
system_parameters_fields = {"image_file_size_limit": fields.String}
system_parameters_fields = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"file_size_limit": fields.Integer,
}
parameters_fields = {
"opening_statement": fields.String,
@ -80,7 +85,12 @@ class AppParameterApi(WebApiResource):
}
},
),
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
},
}

View File

@ -1,11 +1,14 @@
import urllib.parse
from flask import request
from flask_restful import marshal_with
from flask_restful import marshal_with, reqparse
import services
from controllers.web import api
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
from controllers.web.wraps import WebApiResource
from fields.file_fields import file_fields
from core.helper import ssrf_proxy
from fields.file_fields import file_fields, remote_file_info_fields
from services.file_service import FileService
@ -15,6 +18,10 @@ class FileApi(WebApiResource):
# get file from request
file = request.files["file"]
parser = reqparse.RequestParser()
parser.add_argument("source", type=str, required=False, location="args")
source = parser.parse_args().get("source")
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -22,7 +29,7 @@ class FileApi(WebApiResource):
if len(request.files) > 1:
raise TooManyFilesError()
try:
upload_file = FileService.upload_file(file, end_user)
upload_file = FileService.upload_file(file=file, user=end_user, source=source)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
@ -31,4 +38,19 @@ class FileApi(WebApiResource):
return upload_file, 201
class RemoteFileInfoApi(WebApiResource):
@marshal_with(remote_file_info_fields)
def get(self, url):
decoded_url = urllib.parse.unquote(url)
try:
response = ssrf_proxy.head(decoded_url)
return {
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
"file_length": int(response.headers.get("Content-Length", -1)),
}
except Exception as e:
return {"error": str(e)}, 400
api.add_resource(FileApi, "/files/upload")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")

View File

@ -22,6 +22,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError
from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields
from fields.raws import FilesContainedField
from libs import helper
from libs.helper import TimestampField, uuid_value
from models.model import AppMode
@ -58,10 +59,10 @@ class MessageListApi(WebApiResource):
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": fields.Raw,
"inputs": FilesContainedField,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField,

View File

@ -17,7 +17,7 @@ message_fields = {
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"created_at": TimestampField,
}

View File

@ -16,13 +16,14 @@ from core.app.entities.app_invoke_entities import (
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file.message_file_parser import MessageFileParser
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMUsage,
PromptMessage,
PromptMessageContent,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
@ -40,9 +41,9 @@ from core.tools.entities.tool_entities import (
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
@ -66,23 +67,6 @@ class BaseAgentRunner(AppRunner):
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None,
) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param application_generate_entity: application generate entity
:param conversation: conversation
:param app_config: app generate entity
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param memory: memory
:param prompt_messages: prompt messages
:param variables_pool: variables pool
:param db_variables: db variables
:param model_instance: model instance
"""
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
self.conversation = conversation
@ -180,7 +164,7 @@ class BaseAgentRunner(AppRunner):
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
parameter_type = parameter.type.as_normal_type()
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
@ -265,7 +249,7 @@ class BaseAgentRunner(AppRunner):
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
parameter_type = parameter.type.as_normal_type()
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
@ -511,26 +495,24 @@ class BaseAgentRunner(AppRunner):
return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
message_file_parser = MessageFileParser(
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
)
files = message.message_files
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
)
else:
file_objs = []
if not file_objs:
return UserPromptMessage(content=message.query)
else:
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file_obj in file_objs:
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
return UserPromptMessage(content=prompt_message_contents)
else:

View File

@ -1,9 +1,11 @@
import json
from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContent,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
@ -32,9 +34,10 @@ class CotChatAgentRunner(CotAgentRunner):
Organize user query
"""
if self.files:
prompt_message_contents = [TextPromptMessageContent(data=query)]
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:

View File

@ -7,10 +7,15 @@ from typing import Any, Optional, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
@ -390,9 +395,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
Organize user query
"""
if self.files:
prompt_message_contents = [TextPromptMessageContent(data=query)]
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:

View File

@ -53,12 +53,11 @@ class BasicVariablesConfigManager:
VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description"),
description=variable.get("description") or "",
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options"),
default=variable.get("default"),
options=variable.get("options") or [],
)
)

View File

@ -1,11 +1,12 @@
from collections.abc import Sequence
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field, field_validator
from core.file.file_obj import FileExtraConfig
from core.file import FileExtraConfig, FileTransferMethod, FileType
from core.model_runtime.entities.message_entities import PromptMessageRole
from models import AppMode
from models.model import AppMode
class ModelConfigEntity(BaseModel):
@ -69,7 +70,7 @@ class PromptTemplateEntity(BaseModel):
ADVANCED = "advanced"
@classmethod
def value_of(cls, value: str) -> "PromptType":
def value_of(cls, value: str):
"""
Get value of given mode.
@ -93,6 +94,8 @@ class VariableEntityType(str, Enum):
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
class VariableEntity(BaseModel):
@ -102,13 +105,24 @@ class VariableEntity(BaseModel):
variable: str
label: str
description: Optional[str] = None
description: str = ""
type: VariableEntityType
required: bool = False
max_length: Optional[int] = None
options: Optional[list[str]] = None
default: Optional[str] = None
hint: Optional[str] = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
@field_validator("description", mode="before")
@classmethod
def convert_none_description(cls, v: Any) -> str:
return v or ""
@field_validator("options", mode="before")
@classmethod
def convert_none_options(cls, v: Any) -> Sequence[str]:
return v or []
class ExternalDataVariableEntity(BaseModel):
@ -136,7 +150,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
MULTIPLE = "multiple"
@classmethod
def value_of(cls, value: str) -> "RetrieveStrategy":
def value_of(cls, value: str):
"""
Get value of given mode.

View File

@ -1,12 +1,13 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.file.file_obj import FileExtraConfig
from core.file.models import FileExtraConfig
from models import FileUploadConfig
class FileUploadConfigManager:
@classmethod
def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]:
def convert(cls, config: Mapping[str, Any], is_vision: bool = True):
"""
Convert model config to model config
@ -15,19 +16,21 @@ class FileUploadConfigManager:
"""
file_upload_dict = config.get("file_upload")
if file_upload_dict:
if file_upload_dict.get("image"):
if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
image_config = {
"number_limits": file_upload_dict["image"]["number_limits"],
"transfer_methods": file_upload_dict["image"]["transfer_methods"],
if file_upload_dict.get("enabled"):
transform_methods = file_upload_dict.get("allowed_file_upload_methods") or file_upload_dict.get(
"allowed_upload_methods", []
)
data = {
"image_config": {
"number_limits": file_upload_dict["number_limits"],
"transfer_methods": transform_methods,
}
}
if is_vision:
image_config["detail"] = file_upload_dict["image"]["detail"]
if is_vision:
data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
return FileExtraConfig(image_config=image_config)
return None
return FileExtraConfig.model_validate(data)
@classmethod
def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:
@ -39,29 +42,7 @@ class FileUploadConfigManager:
"""
if not config.get("file_upload"):
config["file_upload"] = {}
if not isinstance(config["file_upload"], dict):
raise ValueError("file_upload must be of dict type")
# check image config
if not config["file_upload"].get("image"):
config["file_upload"]["image"] = {"enabled": False}
if config["file_upload"]["image"]["enabled"]:
number_limits = config["file_upload"]["image"]["number_limits"]
if number_limits < 1 or number_limits > 6:
raise ValueError("number_limits must be in [1, 6]")
if is_vision:
detail = config["file_upload"]["image"]["detail"]
if detail not in {"high", "low"}:
raise ValueError("detail must be in ['high', 'low']")
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
if not isinstance(transfer_methods, list):
raise ValueError("transfer_methods must be of list type")
for method in transfer_methods:
if method not in {"remote_url", "local_file"}:
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
else:
FileUploadConfig.model_validate(config["file_upload"])
return config, ["file_upload"]

View File

@ -17,6 +17,6 @@ class WorkflowVariablesConfigManager:
# variables
for variable in user_input_form:
variables.append(VariableEntity(**variable))
variables.append(VariableEntity.model_validate(variable))
return variables

View File

@ -21,11 +21,12 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
@ -96,10 +97,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
file_objs = []
@ -107,8 +114,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id, user_id)
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
@ -120,7 +128,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
task_id=str(uuid.uuid4()),
app_config=app_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

View File

@ -1,31 +1,27 @@
import logging
import os
from collections.abc import Mapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
QueueStopEvent,
QueueTextChunkEvent,
)
from core.moderation.base import ModerationError
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, WorkflowType
@ -44,12 +40,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation: Conversation,
message: Message,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
"""
super().__init__(queue_manager)
self.application_generate_entity = application_generate_entity
@ -57,10 +47,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self.message = message
def run(self) -> None:
"""
Run application
:return:
"""
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
@ -81,7 +67,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id = self.application_generate_entity.user_id
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
if self.application_generate_entity.single_iteration_run:
@ -201,15 +187,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
query: str,
message_id: str,
) -> bool:
"""
Handle input moderation
:param app_record: app record
:param app_generate_entity: application generate entity
:param inputs: inputs
:param query: query
:param message_id: message id
:return:
"""
try:
# process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs(
@ -229,14 +206,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def handle_annotation_reply(
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
) -> bool:
"""
Handle annotation reply
:param app_record: app record
:param message: message
:param query: query
:param app_generate_entity: application generate entity
"""
# annotation reply
annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record,
message=message,
@ -258,8 +227,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
"""
Direct output
:param text: text
:return:
"""
self._publish_event(QueueTextChunkEvent(text=text))

View File

@ -1,7 +1,7 @@
import json
import logging
import time
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@ -9,6 +9,7 @@ from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
@ -50,10 +51,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from events.message_event import message_was_created
from extensions.ext_database import db
from models import Conversation, EndUser, Message, MessageFile
from models.account import Account
from models.model import Conversation, EndUser, Message
from models.enums import CreatedByRole
from models.workflow import (
Workflow,
WorkflowNodeExecution,
@ -120,6 +123,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._wip_workflow_node_executions = {}
self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = []
def process(self):
"""
@ -298,6 +302,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
@ -364,7 +372,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if event.outputs else None,
outputs=event.outputs,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
@ -490,10 +498,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._conversation_name_generate_thread.join()
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
"""
Save message.
:return:
"""
self._refetch_message()
self._message.answer = self._task_state.answer
@ -501,6 +505,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message_files = [
MessageFile(
message_id=self._message.id,
type=file["type"],
transfer_method=file["transfer_method"],
url=file["remote_url"],
belongs_to="assistant",
upload_file_id=file["related_id"],
created_by_role=CreatedByRole.ACCOUNT
if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatedByRole.END_USER,
created_by=self._message.from_account_id or self._message.from_end_user_id or "",
)
for file in self._recorded_files
]
db.session.add_all(message_files)
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
@ -540,7 +560,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
del extras["metadata"]["annotation_reply"]
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
)
def _handle_output_moderation_chunk(self, text: str) -> bool:

View File

@ -18,12 +18,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
from factories import file_factory
from models import Account, App, EndUser
from models.enums import CreatedByRole
logger = logging.getLogger(__name__)
@ -50,7 +50,12 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ...
def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
@ -98,12 +103,19 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True}
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
files = args.get("files") or []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
file_objs = []
@ -116,8 +128,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id, user_id)
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
# init application generate entity
application_generate_entity = AgentChatAppGenerateEntity(
@ -125,7 +136,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

View File

@ -1,35 +1,92 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
from core.app.app_config.entities import VariableEntityType
from core.file import File, FileExtraConfig
from factories import file_factory
if TYPE_CHECKING:
from core.app.app_config.entities import AppConfig, VariableEntity
from models.enums import CreatedByRole
class BaseAppGenerator:
def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
def _prepare_user_inputs(
self,
*,
user_inputs: Optional[Mapping[str, Any]],
app_config: "AppConfig",
user_id: str,
role: "CreatedByRole",
) -> Mapping[str, Any]:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
return filtered_inputs
user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
# Convert files in inputs to File
entity_dictionary = {item.variable: item for item in app_config.variables}
# Convert single file to File
files_inputs = {
k: file_factory.build_from_mapping(
mapping=v,
tenant_id=app_config.tenant_id,
user_id=user_id,
role=role,
config=FileExtraConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
)
for k, v in user_inputs.items()
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
}
# Convert list of files to File
file_list_inputs = {
k: file_factory.build_from_mappings(
mappings=v,
tenant_id=app_config.tenant_id,
user_id=user_id,
role=role,
config=FileExtraConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
)
for k, v in user_inputs.items()
if isinstance(v, list)
# Ensure skip List<File>
and all(isinstance(item, dict) for item in v)
and entity_dictionary[k].type == VariableEntityType.FILE_LIST
}
# Merge all inputs
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.variable)
if var.required and not user_input_value:
raise ValueError(f"{var.variable} is required in input form")
if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None?
return var.default or ""
if (
var.type
in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
}
and user_input_value
and not isinstance(user_input_value, str)
# Check if all files are converted to File
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
raise ValueError("Invalid input type")
if any(
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
):
raise ValueError("Invalid input type")
return user_inputs
def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
user_input_value = inputs.get(var.variable)
if not user_input_value:
if var.required:
raise ValueError(f"{var.variable} is required in input form")
else:
return None
if var.type in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
} and not isinstance(user_input_value, str):
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
@ -41,12 +98,24 @@ class BaseAppGenerator:
except ValueError:
raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntityType.SELECT:
options = var.options or []
options = var.options
if user_input_value not in options:
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
if var.max_length and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
elif var.type == VariableEntityType.FILE:
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
raise ValueError(f"{var.variable} in input form must be a file")
elif var.type == VariableEntityType.FILE_LIST:
if not (
isinstance(user_input_value, list)
and (
all(isinstance(item, dict) for item in user_input_value)
or all(isinstance(item, File) for item in user_input_value)
)
):
raise ValueError(f"{var.variable} in input form must be a list of files")
return user_input_value

View File

@ -27,7 +27,7 @@ from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation
if TYPE_CHECKING:
from core.file.file_obj import FileVar
from core.file.models import File
class AppRunner:
@ -37,7 +37,7 @@ class AppRunner:
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
files: list["File"],
query: Optional[str] = None,
) -> int:
"""
@ -137,7 +137,7 @@ class AppRunner:
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
files: list["File"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,

View File

@ -18,11 +18,12 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, EndUser
logger = logging.getLogger(__name__)
@ -100,12 +101,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True}
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
file_objs = []
@ -118,7 +126,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
)
# get tracing instance
trace_manager = TraceQueueManager(app_model.id)
trace_manager = TraceQueueManager(app_id=app_model.id)
# init application generate entity
application_generate_entity = ChatAppGenerateEntity(
@ -126,15 +134,17 @@ class ChatAppGenerator(MessageBasedAppGenerator):
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
stream=stream,
)
# init generate records

View File

@ -17,12 +17,12 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser, Message
from factories import file_factory
from models import Account, App, EndUser, Message
from models.enums import CreatedByRole
from services.errors.app import MoreLikeThisDisabledError
from services.errors.message import MessageNotExistsError
@ -88,12 +88,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
tenant_id=app_model.tenant_id, config=args.get("model_config")
)
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
file_objs = []
@ -103,6 +110,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id)
# init application generate entity
@ -110,7 +118,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
inputs=self._get_cleaned_inputs(inputs, app_config),
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query,
files=file_objs,
user_id=user.id,
@ -251,10 +259,16 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
override_model_config_dict["model"] = model_dict
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
file_objs = file_factory.build_from_mappings(
mappings=message.message_files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
file_objs = []

View File

@ -26,7 +26,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from models.account import Account
from models import Account
from models.enums import CreatedByRole
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
@ -235,13 +236,13 @@ class MessageBasedAppGenerator(BaseAppGenerator):
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type.value,
transfer_method=file.transfer_method.value,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
url=file.url,
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=("account" if account_id else "end_user"),
created_by=account_id or end_user_id,
created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER),
created_by=account_id or end_user_id or "",
)
db.session.add(message_file)
db.session.commit()

View File

@ -3,7 +3,7 @@ import logging
import os
import threading
import uuid
from collections.abc import Generator
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
@ -20,13 +20,12 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
from models.workflow import Workflow
from factories import file_factory
from models import Account, App, EndUser, Workflow
from models.enums import CreatedByRole
logger = logging.getLogger(__name__)
@ -63,49 +62,46 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
):
"""
Generate App response.
files: Sequence[Mapping[str, Any]] = args.get("files") or []
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
:param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id
"""
inputs = args["inputs"]
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
system_files = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow,
)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id, user_id)
trace_manager = TraceQueueManager(
app_id=app_model.id,
user_id=user.id if isinstance(user, Account) else user.session_id,
)
inputs: Mapping[str, Any] = args["inputs"]
workflow_run_id = str(uuid.uuid4())
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
inputs=self._get_cleaned_inputs(inputs, app_config),
files=file_objs,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
files=system_files,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,

View File

@ -1,21 +1,20 @@
import logging
import os
from typing import Optional, cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom
from models.model import App, EndUser
from models.workflow import WorkflowType
@ -71,7 +70,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
db.session.close()
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested

View File

@ -1,4 +1,3 @@
import json
import logging
import time
from collections.abc import Generator
@ -334,9 +333,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs)
if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
else None,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
)

View File

@ -20,7 +20,6 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
@ -41,9 +40,9 @@ from core.workflow.graph_engine.entities.event import (
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.nodes import NodeType
from core.workflow.nodes.iteration import IterationNodeData
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
@ -137,9 +136,8 @@ class WorkflowBasedAppRunner(AppRunner):
raise ValueError("iteration node id not found in workflow graph")
# Get node class
node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type"))
node_cls = node_classes.get(node_type)
node_cls = cast(type[BaseNode], node_cls)
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping[node_type]
# init variable pool
variable_pool = VariablePool(

View File

@ -1,220 +0,0 @@
from typing import Optional
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
"red": "31;1",
}
class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None:
self.current_node_id = None
def on_event(self, event: GraphEngineEvent) -> None:
if isinstance(event, GraphRunStartedEvent):
self.print_text("\n[GraphRunStartedEvent]", color="pink")
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color="green")
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started(event=event)
elif isinstance(event, NodeRunSucceededEvent):
self.on_workflow_node_execute_succeeded(event=event)
elif isinstance(event, NodeRunFailedEvent):
self.on_workflow_node_execute_failed(event=event)
elif isinstance(event, NodeRunStreamChunkEvent):
self.on_node_text_chunk(event=event)
elif isinstance(event, ParallelBranchRunStartedEvent):
self.on_workflow_parallel_started(event=event)
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
self.on_workflow_parallel_completed(event=event)
elif isinstance(event, IterationRunStartedEvent):
self.on_workflow_iteration_started(event=event)
elif isinstance(event, IterationRunNextEvent):
self.on_workflow_iteration_next(event=event)
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
self.on_workflow_iteration_completed(event=event)
else:
self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
"""
Workflow node execute started
"""
self.print_text("\n[NodeRunStartedEvent]", color="yellow")
self.print_text(f"Node ID: {event.node_id}", color="yellow")
self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
self.print_text(f"Type: {event.node_type.value}", color="yellow")
def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
"""
Workflow node execute succeeded
"""
route_node_state = event.route_node_state
self.print_text("\n[NodeRunSucceededEvent]", color="green")
self.print_text(f"Node ID: {event.node_id}", color="green")
self.print_text(f"Node Title: {event.node_data.title}", color="green")
self.print_text(f"Type: {event.node_type.value}", color="green")
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color="green",
)
self.print_text(
f"Process Data: "
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color="green",
)
self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color="green",
)
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
color="green",
)
def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
"""
Workflow node execute failed
"""
route_node_state = event.route_node_state
self.print_text("\n[NodeRunFailedEvent]", color="red")
self.print_text(f"Node ID: {event.node_id}", color="red")
self.print_text(f"Node Title: {event.node_data.title}", color="red")
self.print_text(f"Type: {event.node_type.value}", color="red")
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(f"Error: {node_run_result.error}", color="red")
self.print_text(
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color="red",
)
self.print_text(
f"Process Data: "
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color="red",
)
self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color="red",
)
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
"""
Publish text chunk
"""
route_node_state = event.route_node_state
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
self.current_node_id = route_node_state.node_id
self.print_text("\n[NodeRunStreamChunkEvent]")
self.print_text(f"Node ID: {route_node_state.node_id}")
node_run_result = route_node_state.node_run_result
if node_run_result:
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
)
self.print_text(event.chunk_content, color="pink", end="")
def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
"""
Publish parallel started
"""
self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
def on_workflow_parallel_completed(
self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
) -> None:
"""
Publish parallel completed
"""
if isinstance(event, ParallelBranchRunSucceededEvent):
color = "blue"
elif isinstance(event, ParallelBranchRunFailedEvent):
color = "red"
self.print_text(
"\n[ParallelBranchRunSucceededEvent]"
if isinstance(event, ParallelBranchRunSucceededEvent)
else "\n[ParallelBranchRunFailedEvent]",
color=color,
)
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
if isinstance(event, ParallelBranchRunFailedEvent):
self.print_text(f"Error: {event.error}", color=color)
def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
"""
Publish iteration started
"""
self.print_text("\n[IterationRunStartedEvent]", color="blue")
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
"""
Publish iteration next
"""
self.print_text("\n[IterationRunNextEvent]", color="blue")
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
self.print_text(f"Iteration Index: {event.index}", color="blue")
def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
"""
Publish iteration completed
"""
self.print_text(
"\n[IterationRunSucceededEvent]"
if isinstance(event, IterationRunSucceededEvent)
else "\n[IterationRunFailedEvent]",
color="blue",
)
self.print_text(f"Node ID: {event.iteration_id}", color="blue")
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
"""Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text
print(f"{text_to_print}", end=end)
def _get_colored_text(self, text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"

View File

@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional
@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
from constants import UUID_NIL
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file.file_obj import FileVar
from core.file.models import File
from core.model_runtime.entities.model_entities import AIModelEntity
from core.ops.ops_trace_manager import TraceQueueManager
@ -23,7 +23,7 @@ class InvokeFrom(Enum):
DEBUGGER = "debugger"
@classmethod
def value_of(cls, value: str) -> "InvokeFrom":
def value_of(cls, value: str):
"""
Get value of given mode.
@ -82,7 +82,7 @@ class AppGenerateEntity(BaseModel):
app_config: AppConfig
inputs: Mapping[str, Any]
files: list[FileVar] = []
files: Sequence[File]
user_id: str
# extras

View File

@ -5,9 +5,10 @@ from typing import Any, Optional
from pydantic import BaseModel, field_validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
class QueueEvent(str, Enum):

View File

@ -1,3 +1,4 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional
@ -119,6 +120,7 @@ class MessageEndStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE_END
id: str
metadata: dict = {}
files: Optional[Sequence[Mapping[str, Any]]] = None
class MessageFileStreamResponse(StreamResponse):
@ -211,7 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
created_by: Optional[dict] = None
created_at: int
finished_at: int
files: Optional[list[dict]] = []
files: Optional[Sequence[Mapping[str, Any]]] = []
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
workflow_run_id: str
@ -296,7 +298,7 @@ class NodeFinishStreamResponse(StreamResponse):
execution_metadata: Optional[dict] = None
created_at: int
finished_at: int
files: Optional[list[dict]] = []
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None

View File

@ -1,49 +0,0 @@
from .segment_group import SegmentGroup
from .segments import (
ArrayAnySegment,
ArraySegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from .types import SegmentType
from .variables import (
ArrayAnyVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
)
__all__ = [
"IntegerVariable",
"FloatVariable",
"ObjectVariable",
"SecretVariable",
"StringVariable",
"ArrayAnyVariable",
"Variable",
"SegmentType",
"SegmentGroup",
"Segment",
"NoneSegment",
"NoneVariable",
"IntegerSegment",
"FloatSegment",
"ObjectSegment",
"ArrayAnySegment",
"StringSegment",
"ArrayStringVariable",
"ArrayNumberVariable",
"ArrayObjectVariable",
"ArraySegment",
]

View File

@ -1,2 +0,0 @@
class VariableError(ValueError):
pass

View File

@ -1,76 +0,0 @@
from collections.abc import Mapping
from typing import Any
from configs import dify_config
from .exc import VariableError
from .segments import (
ArrayAnySegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from .types import SegmentType
from .variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
)
def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if (value_type := mapping.get("value_type")) is None:
raise VariableError("missing value type")
if not mapping.get("name"):
raise VariableError("missing name")
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
case SegmentType.SECRET:
result = SecretVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, int):
result = IntegerVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, float):
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f"invalid number value {value}")
case SegmentType.OBJECT if isinstance(value, dict):
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
result = ArrayStringVariable.model_validate(mapping)
case SegmentType.ARRAY_NUMBER if isinstance(value, list):
result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case _:
raise VariableError(f"not supported value type {value_type}")
if result.size > dify_config.MAX_VARIABLE_SIZE:
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
return result
def build_segment(value: Any, /) -> Segment:
if value is None:
return NoneSegment()
if isinstance(value, str):
return StringSegment(value=value)
if isinstance(value, int):
return IntegerSegment(value=value)
if isinstance(value, float):
return FloatSegment(value=value)
if isinstance(value, dict):
return ObjectSegment(value=value)
if isinstance(value, list):
return ArrayAnySegment(value=value)
raise ValueError(f"not supported value {value}")

View File

@ -1,18 +0,0 @@
import re
from core.workflow.entities.variable_pool import VariablePool
from . import SegmentGroup, factory
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
def convert_template(*, template: str, variable_pool: VariablePool):
parts = re.split(VARIABLE_PATTERN, template)
segments = []
for part in filter(lambda x: x, parts):
if "." in part and (value := variable_pool.get(part.split("."))):
segments.append(value)
else:
segments.append(factory.build_segment(part))
return SegmentGroup(value=segments)

View File

@ -1,22 +0,0 @@
from .segments import Segment
from .types import SegmentType
class SegmentGroup(Segment):
value_type: SegmentType = SegmentType.GROUP
value: list[Segment]
@property
def text(self):
return "".join([segment.text for segment in self.value])
@property
def log(self):
return "".join([segment.log for segment in self.value])
@property
def markdown(self):
return "".join([segment.markdown for segment in self.value])
def to_object(self):
return [segment.to_object() for segment in self.value]

View File

@ -1,126 +0,0 @@
import json
import sys
from collections.abc import Mapping, Sequence
from typing import Any
from pydantic import BaseModel, ConfigDict, field_validator
from .types import SegmentType
class Segment(BaseModel):
model_config = ConfigDict(frozen=True)
value_type: SegmentType
value: Any
@field_validator("value_type")
@classmethod
def validate_value_type(cls, value):
"""
This validator checks if the provided value is equal to the default value of the 'value_type' field.
If the value is different, a ValueError is raised.
"""
if value != cls.model_fields["value_type"].default:
raise ValueError("Cannot modify 'value_type'")
return value
@property
def text(self) -> str:
return str(self.value)
@property
def log(self) -> str:
return str(self.value)
@property
def markdown(self) -> str:
return str(self.value)
@property
def size(self) -> int:
return sys.getsizeof(self.value)
def to_object(self) -> Any:
return self.value
class NoneSegment(Segment):
value_type: SegmentType = SegmentType.NONE
value: None = None
@property
def text(self) -> str:
return "null"
@property
def log(self) -> str:
return "null"
@property
def markdown(self) -> str:
return "null"
class StringSegment(Segment):
value_type: SegmentType = SegmentType.STRING
value: str
class FloatSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER
value: float
class IntegerSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER
value: int
class ObjectSegment(Segment):
value_type: SegmentType = SegmentType.OBJECT
value: Mapping[str, Any]
@property
def text(self) -> str:
return json.dumps(self.model_dump()["value"], ensure_ascii=False)
@property
def log(self) -> str:
return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
@property
def markdown(self) -> str:
return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
class ArraySegment(Segment):
@property
def markdown(self) -> str:
items = []
for item in self.value:
if hasattr(item, "to_markdown"):
items.append(item.to_markdown())
else:
items.append(str(item))
return "\n".join(items)
class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Any]
class ArrayStringSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_STRING
value: Sequence[str]
class ArrayNumberSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_NUMBER
value: Sequence[float | int]
class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]]

View File

@ -1,15 +0,0 @@
from enum import Enum
class SegmentType(str, Enum):
NONE = "none"
NUMBER = "number"
STRING = "string"
SECRET = "secret"
ARRAY_ANY = "array[any]"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
OBJECT = "object"
GROUP = "group"

View File

@ -1,75 +0,0 @@
from pydantic import Field
from core.helper import encrypter
from .segments import (
ArrayAnySegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from .types import SegmentType
class Variable(Segment):
"""
A variable is a segment that has a name.
"""
id: str = Field(
default="",
description="Unique identity for variable. It's only used by environment variables now.",
)
name: str
description: str = Field(default="", description="Description of the variable.")
class StringVariable(StringSegment, Variable):
pass
class FloatVariable(FloatSegment, Variable):
pass
class IntegerVariable(IntegerSegment, Variable):
pass
class ObjectVariable(ObjectSegment, Variable):
pass
class ArrayAnyVariable(ArrayAnySegment, Variable):
pass
class ArrayStringVariable(ArrayStringSegment, Variable):
pass
class ArrayNumberVariable(ArrayNumberSegment, Variable):
pass
class ArrayObjectVariable(ArrayObjectSegment, Variable):
pass
class SecretVariable(StringVariable):
value_type: SegmentType = SegmentType.SECRET
@property
def log(self) -> str:
return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, Variable):
value_type: SegmentType = SegmentType.NONE
value: None = None

View File

@ -53,7 +53,7 @@ class BasedGenerateTaskPipeline:
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream
def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None) -> Exception:
def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None):
"""
Handle error event.
:param event: event
@ -100,7 +100,7 @@ class BasedGenerateTaskPipeline:
return message
def _error_to_stream_response(self, e: Exception) -> ErrorStreamResponse:
def _error_to_stream_response(self, e: Exception):
"""
Error to stream response.
:param e: exception

View File

@ -1,8 +1,11 @@
import json
import time
from collections.abc import Mapping, Sequence
from datetime import datetime, timezone
from typing import Any, Optional, Union, cast
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
@ -27,27 +30,26 @@ from core.app.entities.task_entities import (
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.file.file_obj import FileVar
from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
CreatedByRole,
Workflow,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
)
@ -117,7 +119,7 @@ class WorkflowCycleManage:
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Optional[str] = None,
outputs: Mapping[str, Any] | None = None,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
@ -133,8 +135,10 @@ class WorkflowCycleManage:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
outputs = WorkflowEntry.handle_special_values(outputs)
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
workflow_run.outputs = outputs
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
@ -230,30 +234,30 @@ class WorkflowCycleManage:
self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution:
# init workflow node execution
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
workflow_node_execution.index = event.node_run_index
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
workflow_node_execution.index = event.node_run_index
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
session.add(workflow_node_execution)
session.commit()
session.refresh(workflow_node_execution)
self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution
return workflow_node_execution
def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
@ -265,6 +269,7 @@ class WorkflowCycleManage:
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
@ -276,7 +281,7 @@ class WorkflowCycleManage:
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.execution_metadata: execution_metadata,
WorkflowNodeExecution.finished_at: finished_at,
@ -286,10 +291,11 @@ class WorkflowCycleManage:
db.session.commit()
db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.finished_at = finished_at
@ -308,6 +314,7 @@ class WorkflowCycleManage:
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
@ -317,7 +324,7 @@ class WorkflowCycleManage:
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,
@ -326,11 +333,12 @@ class WorkflowCycleManage:
db.session.commit()
db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
@ -637,7 +645,7 @@ class WorkflowCycleManage:
),
)
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
@ -646,15 +654,15 @@ class WorkflowCycleManage:
if not outputs_dict:
return []
files = []
for output_var, output_value in outputs_dict.items():
file_vars = self._fetch_files_from_variable_value(output_value)
if file_vars:
files.extend(file_vars)
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
# Remove None
files = [file for file in files if file]
# Flatten list
files = [file for sublist in files for file in sublist]
return files
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]:
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from variable value
:param value: variable value
@ -666,17 +674,17 @@ class WorkflowCycleManage:
files = []
if isinstance(value, list):
for item in value:
file_var = self._get_file_var_from_value(item)
if file_var:
files.append(file_var)
file = self._get_file_var_from_value(item)
if file:
files.append(file)
elif isinstance(value, dict):
file_var = self._get_file_var_from_value(value)
if file_var:
files.append(file_var)
file = self._get_file_var_from_value(value)
if file:
files.append(file)
return files
def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]:
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
"""
Get file var from value
:param value: variable value
@ -685,14 +693,11 @@ class WorkflowCycleManage:
if not value:
return None
if isinstance(value, dict):
if "__variant" in value and value["__variant"] == FileVar.__name__:
return value
elif isinstance(value, FileVar):
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
return value
elif isinstance(value, File):
return value.to_dict()
return None
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run

View File

@ -1,29 +0,0 @@
import enum
from typing import Any
from pydantic import BaseModel
class PromptMessageFileType(enum.Enum):
IMAGE = "image"
@staticmethod
def value_of(value):
for member in PromptMessageFileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class PromptMessageFile(BaseModel):
type: PromptMessageFileType
data: Any = None
class ImagePromptMessageFile(PromptMessageFile):
class DETAIL(enum.Enum):
LOW = "low"
HIGH = "high"
type: PromptMessageFileType = PromptMessageFileType.IMAGE
detail: DETAIL = DETAIL.LOW

View File

@ -0,0 +1,19 @@
from .constants import FILE_MODEL_IDENTITY
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
from .models import (
File,
FileExtraConfig,
ImageConfig,
)
__all__ = [
"FileType",
"FileExtraConfig",
"FileTransferMethod",
"FileBelongsTo",
"File",
"ImageConfig",
"FileAttribute",
"ArrayFileAttribute",
"FILE_MODEL_IDENTITY",
]

View File

@ -1,145 +0,0 @@
import enum
from typing import Any, Optional
from pydantic import BaseModel
from core.file.tool_file_parser import ToolFileParser
from core.file.upload_file_parser import UploadFileParser
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from extensions.ext_database import db
class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
class FileType(enum.Enum):
IMAGE = "image"
@staticmethod
def value_of(value):
for member in FileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileTransferMethod(enum.Enum):
REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file"
@staticmethod
def value_of(value):
for member in FileTransferMethod:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileBelongsTo(enum.Enum):
USER = "user"
ASSISTANT = "assistant"
@staticmethod
def value_of(value):
for member in FileBelongsTo:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileVar(BaseModel):
id: Optional[str] = None # message file id
tenant_id: str
type: FileType
transfer_method: FileTransferMethod
url: Optional[str] = None # remote url
related_id: Optional[str] = None
extra_config: Optional[FileExtraConfig] = None
filename: Optional[str] = None
extension: Optional[str] = None
mime_type: Optional[str] = None
def to_dict(self) -> dict:
return {
"__variant": self.__class__.__name__,
"tenant_id": self.tenant_id,
"type": self.type.value,
"transfer_method": self.transfer_method.value,
"url": self.preview_url,
"remote_url": self.url,
"related_id": self.related_id,
"filename": self.filename,
"extension": self.extension,
"mime_type": self.mime_type,
}
def to_markdown(self) -> str:
"""
Convert file to markdown
:return:
"""
preview_url = self.preview_url
if self.type == FileType.IMAGE:
text = f'![{self.filename or ""}]({preview_url})'
else:
text = f"[{self.filename or preview_url}]({preview_url})"
return text
@property
def data(self) -> Optional[str]:
"""
Get image data, file signed url or base64 data
depending on config MULTIMODAL_SEND_IMAGE_FORMAT
:return:
"""
return self._get_data()
@property
def preview_url(self) -> Optional[str]:
"""
Get signed preview url
:return:
"""
return self._get_data(force_url=True)
@property
def prompt_message_content(self) -> ImagePromptMessageContent:
if self.type == FileType.IMAGE:
image_config = self.extra_config.image_config
return ImagePromptMessageContent(
data=self.data,
detail=ImagePromptMessageContent.DETAIL.HIGH
if image_config.get("detail") == "high"
else ImagePromptMessageContent.DETAIL.LOW,
)
def _get_data(self, force_url: bool = False) -> Optional[str]:
from models.model import UploadFile
if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = (
db.session.query(UploadFile)
.filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
.first()
)
return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
extension = self.extension
# add sign url
return ToolFileParser.get_tool_file_manager().sign_file(
tool_file_id=self.related_id, extension=extension
)
return None

View File

@ -1,243 +0,0 @@
import re
from collections.abc import Mapping, Sequence
from typing import Any, Union
from urllib.parse import parse_qs, urlparse
import requests
from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser, MessageFile, UploadFile
from services.file_service import IMAGE_EXTENSIONS
class MessageFileParser:
def __init__(self, tenant_id: str, app_id: str) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
def validate_and_transform_files_arg(
self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
) -> list[FileVar]:
"""
validate and transform files arg
:param files:
:param file_extra_config:
:param user:
:return:
"""
for file in files:
if not isinstance(file, dict):
raise ValueError("Invalid file format, must be dict")
if not file.get("type"):
raise ValueError("Missing file type")
FileType.value_of(file.get("type"))
if not file.get("transfer_method"):
raise ValueError("Missing file transfer method")
FileTransferMethod.value_of(file.get("transfer_method"))
if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
if not file.get("url"):
raise ValueError("Missing file url")
if not file.get("url").startswith("http"):
raise ValueError("Invalid file url")
if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
raise ValueError("Missing file upload_file_id")
if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
raise ValueError("Missing file tool_file_id")
# transform files to file objs
type_file_objs = self._to_file_objs(files, file_extra_config)
# validate files
new_files = []
for file_type, file_objs in type_file_objs.items():
if file_type == FileType.IMAGE:
# parse and validate files
image_config = file_extra_config.image_config
# check if image file feature is enabled
if not image_config:
continue
# Validate number of files
if len(files) > image_config["number_limits"]:
raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
for file_obj in file_objs:
# Validate transfer method
if file_obj.transfer_method.value not in image_config["transfer_methods"]:
raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")
# Validate file type
if file_obj.type != FileType.IMAGE:
raise ValueError(f"Invalid file type: {file_obj.type}")
if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
# check remote url valid and is image
result, error = self._check_image_remote_url(file_obj.url)
if result is False:
raise ValueError(error)
elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
# get upload file from upload_file_id
upload_file = (
db.session.query(UploadFile)
.filter(
UploadFile.id == file_obj.related_id,
UploadFile.tenant_id == self.tenant_id,
UploadFile.created_by == user.id,
UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
UploadFile.extension.in_(IMAGE_EXTENSIONS),
)
.first()
)
# check upload file is belong to tenant and user
if not upload_file:
raise ValueError("Invalid upload file")
new_files.append(file_obj)
# return all file objs
return new_files
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
"""
transform message files
:param files:
:param file_extra_config:
:return:
"""
# transform files to file objs
type_file_objs = self._to_file_objs(files, file_extra_config)
# return all file objs
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
def _to_file_objs(
self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
) -> dict[FileType, list[FileVar]]:
"""
transform files to file objs
:param files:
:param file_extra_config:
:return:
"""
type_file_objs: dict[FileType, list[FileVar]] = {
# Currently only support image
FileType.IMAGE: []
}
if not files:
return type_file_objs
# group by file type and convert file args or message files to FileObj
for file in files:
if isinstance(file, MessageFile):
if file.belongs_to == FileBelongsTo.ASSISTANT.value:
continue
file_obj = self._to_file_obj(file, file_extra_config)
if file_obj.type not in type_file_objs:
continue
type_file_objs[file_obj.type].append(file_obj)
return type_file_objs
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
"""
transform file to file obj
:param file:
:return:
"""
if isinstance(file, dict):
transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
if transfer_method != FileTransferMethod.TOOL_FILE:
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get("type")),
transfer_method=transfer_method,
url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
extra_config=file_extra_config,
)
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get("type")),
transfer_method=transfer_method,
url=None,
related_id=file.get("tool_file_id"),
extra_config=file_extra_config,
)
else:
return FileVar(
id=file.id,
tenant_id=self.tenant_id,
type=FileType.value_of(file.type),
transfer_method=FileTransferMethod.value_of(file.transfer_method),
url=file.url,
related_id=file.upload_file_id or None,
extra_config=file_extra_config,
)
def _check_image_remote_url(self, url):
try:
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
" Chrome/91.0.4472.124 Safari/537.36"
}
def is_s3_presigned_url(url):
try:
parsed_url = urlparse(url)
if "amazonaws.com" not in parsed_url.netloc:
return False
query_params = parse_qs(parsed_url.query)
def check_presign_v2(query_params):
required_params = ["Signature", "Expires"]
for param in required_params:
if param not in query_params:
return False
if not query_params["Expires"][0].isdigit():
return False
signature = query_params["Signature"][0]
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
return False
return True
def check_presign_v4(query_params):
required_params = ["X-Amz-Signature", "X-Amz-Expires"]
for param in required_params:
if param not in query_params:
return False
if not query_params["X-Amz-Expires"][0].isdigit():
return False
signature = query_params["X-Amz-Signature"][0]
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
return False
return True
return check_presign_v4(query_params) or check_presign_v2(query_params)
except Exception:
return False
if is_s3_presigned_url(url):
response = requests.get(url, headers=headers, allow_redirects=True)
if response.status_code in {200, 304}:
return True, ""
response = requests.head(url, headers=headers, allow_redirects=True)
if response.status_code in {200, 304}:
return True, ""
else:
return False, "URL does not exist."
except requests.RequestException as e:
return False, f"Error checking URL: {e}"

View File

@ -1,4 +1,9 @@
tool_file_manager = {"manager": None}
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from core.tools.tool_file_manager import ToolFileManager
tool_file_manager: dict[str, Any] = {"manager": None}
class ToolFileParser:

View File

@ -1,79 +0,0 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from typing import Optional
from configs import dify_config
from extensions.ext_storage import storage
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
class UploadFileParser:
@classmethod
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
if not upload_file:
return None
if upload_file.extension not in IMAGE_EXTENSIONS:
return None
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
return cls.get_signed_temp_image_url(upload_file.id)
else:
# get image file base64
try:
data = storage.load(upload_file.key)
except FileNotFoundError:
logging.error(f"File not found: {upload_file.key}")
return None
encoded_string = base64.b64encode(data).decode("utf-8")
return f"data:{upload_file.mime_type};base64,{encoded_string}"
@classmethod
def get_signed_temp_image_url(cls, upload_file_id) -> str:
"""
get signed url from upload file
:param upload_file: UploadFile object
:return:
"""
base_url = dify_config.FILES_URL
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@classmethod
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
:param upload_file_id: file id
:param timestamp: timestamp
:param nonce: nonce
:param sign: signature
:return:
"""
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

View File

@ -13,8 +13,11 @@ SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "")
SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "")
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3"))
proxies = (
{"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL}
proxy_mounts = (
{
"http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL),
}
if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL
else None
)
@ -33,11 +36,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
while retries <= max_retries:
try:
if SSRF_PROXY_ALL_URL:
response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs)
elif proxies:
response = httpx.request(method=method, url=url, proxies=proxies, **kwargs)
with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client:
response = client.request(method=method, url=url, **kwargs)
elif proxy_mounts:
with httpx.Client(mounts=proxy_mounts) as client:
response = client.request(method=method, url=url, **kwargs)
else:
response = httpx.request(method=method, url=url, **kwargs)
with httpx.Client() as client:
response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST:
return response

View File

@ -1,8 +1,9 @@
from typing import Optional
from flask import Config, Flask
from flask import Flask
from pydantic import BaseModel
from configs import dify_config
from core.entities.provider_entities import QuotaUnit, RestrictModel
from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType
@ -44,32 +45,30 @@ class HostingConfiguration:
moderation_config: HostedModerationConfig = None
def init_app(self, app: Flask) -> None:
config = app.config
if config.get("EDITION") != "CLOUD":
if dify_config.EDITION != "CLOUD":
return
self.provider_map["azure_openai"] = self.init_azure_openai(config)
self.provider_map["openai"] = self.init_openai(config)
self.provider_map["anthropic"] = self.init_anthropic(config)
self.provider_map["minimax"] = self.init_minimax(config)
self.provider_map["spark"] = self.init_spark(config)
self.provider_map["zhipuai"] = self.init_zhipuai(config)
self.provider_map["azure_openai"] = self.init_azure_openai()
self.provider_map["openai"] = self.init_openai()
self.provider_map["anthropic"] = self.init_anthropic()
self.provider_map["minimax"] = self.init_minimax()
self.provider_map["spark"] = self.init_spark()
self.provider_map["zhipuai"] = self.init_zhipuai()
self.moderation_config = self.init_moderation_config(config)
self.moderation_config = self.init_moderation_config()
@staticmethod
def init_azure_openai(app_config: Config) -> HostingProvider:
def init_azure_openai() -> HostingProvider:
quota_unit = QuotaUnit.TIMES
if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"):
if dify_config.HOSTED_AZURE_OPENAI_ENABLED:
credentials = {
"openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"),
"openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"),
"openai_api_key": dify_config.HOSTED_AZURE_OPENAI_API_KEY,
"openai_api_base": dify_config.HOSTED_AZURE_OPENAI_API_BASE,
"base_model_name": "gpt-35-turbo",
}
quotas = []
hosted_quota_limit = int(app_config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
restrict_models=[
@ -122,31 +121,31 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
def init_openai(self, app_config: Config) -> HostingProvider:
def init_openai(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas = []
if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
if dify_config.HOSTED_OPENAI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"openai_api_key": app_config.get("HOSTED_OPENAI_API_KEY"),
"openai_api_key": dify_config.HOSTED_OPENAI_API_KEY,
}
if app_config.get("HOSTED_OPENAI_API_BASE"):
credentials["openai_api_base"] = app_config.get("HOSTED_OPENAI_API_BASE")
if dify_config.HOSTED_OPENAI_API_BASE:
credentials["openai_api_base"] = dify_config.HOSTED_OPENAI_API_BASE
if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"):
credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION")
if dify_config.HOSTED_OPENAI_API_ORGANIZATION:
credentials["openai_organization"] = dify_config.HOSTED_OPENAI_API_ORGANIZATION
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
@ -156,26 +155,26 @@ class HostingConfiguration:
)
@staticmethod
def init_anthropic(app_config: Config) -> HostingProvider:
def init_anthropic() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
quotas = []
if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"):
hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
quotas.append(trial_quota)
if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"):
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
paid_quota = PaidHostingQuota()
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"anthropic_api_key": app_config.get("HOSTED_ANTHROPIC_API_KEY"),
"anthropic_api_key": dify_config.HOSTED_ANTHROPIC_API_KEY,
}
if app_config.get("HOSTED_ANTHROPIC_API_BASE"):
credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE")
if dify_config.HOSTED_ANTHROPIC_API_BASE:
credentials["anthropic_api_url"] = dify_config.HOSTED_ANTHROPIC_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
@ -185,9 +184,9 @@ class HostingConfiguration:
)
@staticmethod
def init_minimax(app_config: Config) -> HostingProvider:
def init_minimax() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if app_config.get("HOSTED_MINIMAX_ENABLED"):
if dify_config.HOSTED_MINIMAX_ENABLED:
quotas = [FreeHostingQuota()]
return HostingProvider(
@ -203,9 +202,9 @@ class HostingConfiguration:
)
@staticmethod
def init_spark(app_config: Config) -> HostingProvider:
def init_spark() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if app_config.get("HOSTED_SPARK_ENABLED"):
if dify_config.HOSTED_SPARK_ENABLED:
quotas = [FreeHostingQuota()]
return HostingProvider(
@ -221,9 +220,9 @@ class HostingConfiguration:
)
@staticmethod
def init_zhipuai(app_config: Config) -> HostingProvider:
def init_zhipuai() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if app_config.get("HOSTED_ZHIPUAI_ENABLED"):
if dify_config.HOSTED_ZHIPUAI_ENABLED:
quotas = [FreeHostingQuota()]
return HostingProvider(
@ -239,17 +238,15 @@ class HostingConfiguration:
)
@staticmethod
def init_moderation_config(app_config: Config) -> HostedModerationConfig:
if app_config.get("HOSTED_MODERATION_ENABLED") and app_config.get("HOSTED_MODERATION_PROVIDERS"):
return HostedModerationConfig(
enabled=True, providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(",")
)
def init_moderation_config() -> HostedModerationConfig:
if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS:
return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(","))
return HostedModerationConfig(enabled=False)
@staticmethod
def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
models_str = app_config.get(env_var)
def parse_restrict_models_from_env(env_var: str) -> list[RestrictModel]:
models_str = dify_config.model_dump().get(env_var)
models_list = models_str.split(",") if models_str else []
return [
RestrictModel(model=model_name.strip(), model_type=ModelType.LLM)

View File

@ -8,6 +8,8 @@ from core.llm_generator.output_parser.suggested_questions_after_answer import Su
from core.llm_generator.prompts import (
CONVERSATION_TITLE_PROMPT,
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
from core.model_manager import ModelManager
@ -239,6 +241,54 @@ class LLMGenerator:
return rule_config
@classmethod
def generate_code(
cls,
tenant_id: str,
instruction: str,
model_config: dict,
code_language: str = "javascript",
max_tokens: int = 1000,
) -> dict:
if code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else:
prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
prompt = prompt_template.format(
inputs={
"INSTRUCTION": instruction,
"CODE_LANGUAGE": code_language,
},
remove_template_variables=False,
)
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider") if model_config else None,
model=model_config.get("name") if model_config else None,
)
prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}
try:
response = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
)
generated_code = response.message.content
return {"code": generated_code, "language": code_language, "error": ""}
except InvokeError as e:
error = str(e)
return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logging.exception(e)
return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
@classmethod
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
prompt = GENERATOR_QA_PROMPT.format(language=document_language)

View File

@ -61,6 +61,73 @@ User Input: yo, 你今天咋样?
User Input:
""" # noqa: E501
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE = (
"You are an expert programmer. Generate code based on the following instructions:\n\n"
"Instructions: {{INSTRUCTION}}\n\n"
"Write the code in {{CODE_LANGUAGE}}.\n\n"
"Please ensure that you meet the following requirements:\n"
"1. Define a function named 'main'.\n"
"2. The 'main' function must return a dictionary (dict).\n"
"3. You may modify the arguments of the 'main' function, but include appropriate type hints.\n"
"4. The returned dictionary should contain at least one key-value pair.\n\n"
"5. You may ONLY use the following libraries in your code: \n"
"- json\n"
"- datetime\n"
"- math\n"
"- random\n"
"- re\n"
"- string\n"
"- sys\n"
"- time\n"
"- traceback\n"
"- uuid\n"
"- os\n"
"- base64\n"
"- hashlib\n"
"- hmac\n"
"- binascii\n"
"- collections\n"
"- functools\n"
"- operator\n"
"- itertools\n\n"
"Example:\n"
"def main(arg1: str, arg2: int) -> dict:\n"
" return {\n"
' "result": arg1 * arg2,\n'
" }\n\n"
"IMPORTANT:\n"
"- Provide ONLY the code without any additional explanations, comments, or markdown formatting.\n"
"- DO NOT use markdown code blocks (``` or ``` python). Return the raw code directly.\n"
"- The code should start immediately after this instruction, without any preceding newlines or spaces.\n"
"- The code should be complete, functional, and follow best practices for {{CODE_LANGUAGE}}.\n\n"
"- Always use the format return {'result': ...} for the output.\n\n"
"Generated Code:\n"
)
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
"You are an expert programmer. Generate code based on the following instructions:\n\n"
"Instructions: {{INSTRUCTION}}\n\n"
"Write the code in {{CODE_LANGUAGE}}.\n\n"
"Please ensure that you meet the following requirements:\n"
"1. Define a function named 'main'.\n"
"2. The 'main' function must return an object.\n"
"3. You may modify the arguments of the 'main' function, but include appropriate JSDoc annotations.\n"
"4. The returned object should contain at least one key-value pair.\n\n"
"5. The returned object should always be in the format: {result: ...}\n\n"
"Example:\n"
"function main(arg1, arg2) {\n"
" return {\n"
" result: arg1 * arg2\n"
" };\n"
"}\n\n"
"IMPORTANT:\n"
"- Provide ONLY the code without any additional explanations, comments, or markdown formatting.\n"
"- DO NOT use markdown code blocks (``` or ``` javascript). Return the raw code directly.\n"
"- The code should start immediately after this instruction, without any preceding newlines or spaces.\n"
"- The code should be complete, functional, and follow best practices for {{CODE_LANGUAGE}}.\n\n"
"Generated Code:\n"
)
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n"

View File

@ -1,18 +1,21 @@
from typing import Optional
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file.message_file_parser import MessageFileParser
from core.file import file_manager
from core.file.models import FileType
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageRole,
TextPromptMessageContent,
UserPromptMessage,
)
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from factories import file_factory
from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import WorkflowRun
@ -65,13 +68,12 @@ class TokenBufferMemory:
messages = list(reversed(thread_messages))
message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id)
prompt_messages = []
for message in messages:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:
file_extra_config = None
if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
else:
if message.workflow_run_id:
@ -84,17 +86,22 @@ class TokenBufferMemory:
workflow_run.workflow.features_dict, is_vision=False
)
if file_extra_config:
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
if file_extra_config and app_record:
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
)
else:
file_objs = []
if not file_objs:
prompt_messages.append(UserPromptMessage(content=message.query))
else:
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file_obj in file_objs:
prompt_message_contents.append(file_obj.prompt_message_content)
if file_obj.type in {FileType.IMAGE, FileType.AUDIO}:
prompt_message = file_manager.to_prompt_message_content(file_obj)
prompt_message_contents.append(prompt_message)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:

View File

@ -1,7 +1,7 @@
import logging
import os
from collections.abc import Callable, Generator, Sequence
from typing import IO, Optional, Union, cast
from collections.abc import Callable, Generator, Iterable, Sequence
from typing import IO, Any, Optional, Union, cast
from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@ -274,7 +274,7 @@ class ModelInstance:
user=user,
)
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str:
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
"""
Invoke large language tts model
@ -298,7 +298,7 @@ class ModelInstance:
voice=voice,
)
def _round_robin_invoke(self, function: Callable, *args, **kwargs):
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
"""
Round-robin invoke
:param function: function to invoke

View File

@ -218,7 +218,7 @@ For instance, Xinference supports `max_tokens`, `temperature`, and `top_p` param
However, some vendors may support different parameters for different models. For example, the `OpenLLM` vendor supports `top_k`, but not all models provided by this vendor support `top_k`. Let's say model A supports `top_k` but model B does not. In such cases, we need to dynamically generate the model parameter schema, as illustrated below:
```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
used to define customizable model schema
"""

Some files were not shown because too many files have changed in this diff Show More