diff --git a/.devcontainer/post_start_command.sh b/.devcontainer/post_start_command.sh index e3d5a6d59d..56e87614ba 100755 --- a/.devcontainer/post_start_command.sh +++ b/.devcontainer/post_start_command.sh @@ -1,3 +1,3 @@ #!/bin/bash -poetry install -C api \ No newline at end of file +cd api && poetry install \ No newline at end of file diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index deae361d2f..db74d8b112 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -50,7 +50,7 @@ jobs: echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV - name: Login to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} @@ -115,7 +115,7 @@ jobs: merge-multiple: true - name: Login to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} diff --git a/README.md b/README.md index cd783501e2..61bd0d1e26 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,56 @@

+## Table of Contents +0. [Quick-Start🚀](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start) + +1. [Intro📖](https://github.com/langgenius/dify?tab=readme-ov-file#intro) + +2. [How to use🔧](https://github.com/langgenius/dify?tab=readme-ov-file#using-dify) + +3. [Stay Ahead🏃](https://github.com/langgenius/dify?tab=readme-ov-file#staying-ahead) + +4. [Next Steps🏹](https://github.com/langgenius/dify?tab=readme-ov-file#next-steps) + +5. [Contributing💪](https://github.com/langgenius/dify?tab=readme-ov-file#contributing) + +6. [Community and Contact🏠](https://github.com/langgenius/dify?tab=readme-ov-file#community--contact) + +7. [Star-History📈](https://github.com/langgenius/dify?tab=readme-ov-file#star-history) + +8. [Security🔒](https://github.com/langgenius/dify?tab=readme-ov-file#security-disclosure) + +9. [License🤝](https://github.com/langgenius/dify?tab=readme-ov-file#license) + +> Make sure you read through this README before you start using Dify😊 + + +## Quick start +The quickest way to deploy Dify locally is to run our [docker-compose.yml](https://github.com/langgenius/dify/blob/main/docker/docker-compose.yaml). Follow the instructions below to start in 5 minutes. + +> Before installing Dify, make sure your machine meets the following minimum system requirements: +> +>- CPU >= 2 Cores +>- RAM >= 4 GiB +>- Docker and Docker Compose installed
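A quick way to verify these prerequisites (a minimal sketch; `nproc` and `free` are Linux commands, and output formats vary by platform):

```bash
docker --version          # Docker is installed
docker compose version    # Docker Compose v2 is available
nproc                     # CPU cores (should be >= 2)
free -h                   # total RAM (should be >= 4 GiB)
```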
+ +Run the following command in your terminal to clone the whole repo. +```bash +git clone https://github.com/langgenius/dify.git +``` +After cloning, run the following commands one by one. +```bash +cd dify +cd docker +cp .env.example .env +docker compose up -d +``` + +After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. You will be asked to set up an admin account. +For more information on the quick setup, check [here](https://docs.dify.ai/getting-started/install-self-hosted/docker-compose). + +## Intro Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:
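Once `docker compose up -d` from the quick start above has finished, a quick health check might look like this (a sketch; the exact service list depends on the compose file):

```bash
cd dify/docker
docker compose ps    # every service should report a running (healthy) state
```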

@@ -79,73 +129,6 @@ Dify is an open-source LLM app development platform. Its intuitive interface com All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic. -## Feature comparison
| Feature | Dify.AI | LangChain | Flowise | OpenAI Assistants API |
| --- | --- | --- | --- | --- |
| Programming Approach | API + App-oriented | Python Code | App-oriented | API-oriented |
| Supported LLMs | Rich Variety | Rich Variety | Rich Variety | OpenAI-only |
| RAG Engine | | | | |
| Agent | | | | |
| Workflow | | | | |
| Observability | | | | |
| Enterprise Features (SSO/Access control) | | | | |
| Local Deployment | | | | |
- ## Using Dify - **Cloud
** @@ -166,30 +149,21 @@ Star Dify on GitHub and be instantly notified of new releases. ![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - - -## Quick start -> Before installing Dify, make sure your machine meets the following minimum system requirements: -> ->- CPU >= 2 Core ->- RAM >= 4 GiB - -
- -The easiest way to start the Dify server is to run our [docker-compose.yml](docker/docker-compose.yaml) file. Before running the installation command, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: - -```bash -cd docker -cp .env.example .env -docker compose up -d -``` - -After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. - -> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) - ## Next steps +Go to [quick-start](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start) to set up your Dify instance, or set it up from source code. + +#### If you... +If you forget your admin account, you can refer to this [guide](https://docs.dify.ai/getting-started/install-self-hosted/faqs#id-4.-how-to-reset-the-password-of-the-admin-account) to reset the password. + +> Run `docker compose up` without `-d` to print logs in your terminal. This can be useful if you encounter unknown problems while using Dify. + +If you encounter a system error and would like help via GitHub issues, always paste the error logs in your report to speed up the conversation. Go to [Community & contact](https://github.com/langgenius/dify?tab=readme-ov-file#community--contact) for more information. + +> Please read the [Dify Documentation](https://docs.dify.ai/) for detailed usage guidance. Most potential problems are explained in the docs. + +> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) + If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes. @@ -228,6 +202,7 @@ At the same time, please consider supporting Dify by sharing it on social media * [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). * [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community. * [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. +* When reporting an error, attach logs whenever possible to maximize solution efficiency.
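A minimal sketch of the configuration-change cycle described in the customization note above (assuming you are in the `docker` directory; the variable name below is only an illustrative example — consult `.env.example` for the actual keys):

```bash
# Override a value documented in .env.example (hypothetical key shown)
echo "EXPOSE_NGINX_PORT=8080" >> .env
# Re-create the containers so the new configuration takes effect
docker compose up -d
# Follow service logs to verify everything comes up cleanly
docker compose logs -f
```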
## Star history diff --git a/api/Dockerfile b/api/Dockerfile index 1d2c12fee4..c71317f797 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -55,7 +55,12 @@ RUN apt-get update \ && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \ && apt-get update \ # For Security - && apt-get install -y --no-install-recommends libsqlite3-0=3.46.1-1 \ + && apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \ + && if [ "$(dpkg --print-architecture)" = "amd64" ]; then \ + apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1+b1; \ + else \ + apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1; \ + fi \ # install a chinese font to support the use of tools like matplotlib && apt-get install -y fonts-noto-cjk \ && apt-get autoremove -y \ diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 3e11d5fe6b..f88019fbb6 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -10,7 +10,6 @@ from pydantic import ( PositiveInt, computed_field, ) -from pydantic_extra_types.timezone_name import TimeZoneName from pydantic_settings import BaseSettings from configs.feature.hosted_service import HostedServiceConfig @@ -393,9 +392,8 @@ class LoggingConfig(BaseSettings): default=None, ) - LOG_TZ: Optional[TimeZoneName] = Field( - description="Timezone for log timestamps. Allowed timezone values can be referred to IANA Time Zone Database," - " e.g., 'America/New_York')", + LOG_TZ: Optional[str] = Field( + description="Timezone for log timestamps (e.g., 'America/New_York')", default=None, ) diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 38bb804613..4be761747d 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -16,6 +16,7 @@ from configs.middleware.storage.supabase_storage_config import SupabaseStorageCo from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig +from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig from configs.middleware.vdb.chroma_config import ChromaConfig from configs.middleware.vdb.couchbase_config import CouchbaseConfig from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig @@ -259,5 +260,6 @@ class MiddlewareConfig( UpstashConfig, TidbOnQdrantConfig, OceanBaseVectorConfig, + BaiduVectorDBConfig, ): pass diff --git a/api/controllers/common/errors.py b/api/controllers/common/errors.py new file mode 100644 index 0000000000..c71f1ce5a3 --- /dev/null +++ b/api/controllers/common/errors.py @@ -0,0 +1,6 @@ +from werkzeug.exceptions import HTTPException + + +class FilenameNotExistsError(HTTPException): + code = 400 + description = "The specified filename does not exist." 
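As a usage note for the new `FilenameNotExistsError` above: a minimal sketch of how this diff applies it in the upload endpoints further down (the handler below is hypothetical and omits the file-count and size checks the real handlers also perform):

```python
from flask import request

from controllers.common.errors import FilenameNotExistsError


def upload():
    # Hypothetical Flask view showing the validation pattern:
    # reject a file part that carries no filename before reading its bytes.
    file = request.files["file"]
    if not file.filename:
        # As an HTTPException subclass with code = 400, raising it yields
        # a Bad Request response carrying the class's description.
        raise FilenameNotExistsError
    return {"filename": file.filename}
```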
diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py new file mode 100644 index 0000000000..ed24b265ef --- /dev/null +++ b/api/controllers/common/helpers.py @@ -0,0 +1,58 @@ +import mimetypes +import os +import re +import urllib.parse +from uuid import uuid4 + +import httpx +from pydantic import BaseModel + + +class FileInfo(BaseModel): + filename: str + extension: str + mimetype: str + size: int + + +def guess_file_info_from_response(response: httpx.Response): + url = str(response.url) + # Try to extract filename from URL + parsed_url = urllib.parse.urlparse(url) + url_path = parsed_url.path + filename = os.path.basename(url_path) + + # If filename couldn't be extracted, use Content-Disposition header + if not filename: + content_disposition = response.headers.get("Content-Disposition") + if content_disposition: + filename_match = re.search(r'filename="?(.+)"?', content_disposition) + if filename_match: + filename = filename_match.group(1) + + # If still no filename, generate a unique one + if not filename: + unique_name = str(uuid4()) + filename = f"{unique_name}" + + # Guess MIME type from filename first, then URL + mimetype, _ = mimetypes.guess_type(filename) + if mimetype is None: + mimetype, _ = mimetypes.guess_type(url) + if mimetype is None: + # If guessing fails, use Content-Type from response headers + mimetype = response.headers.get("Content-Type", "application/octet-stream") + + extension = os.path.splitext(filename)[1] + + # Ensure filename has an extension + if not extension: + extension = mimetypes.guess_extension(mimetype) or ".bin" + filename = f"{filename}{extension}" + + return FileInfo( + filename=filename, + extension=extension, + mimetype=mimetype, + size=int(response.headers.get("Content-Length", -1)), + ) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 8198e9d0ff..9d0dd3fb23 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -2,9 +2,21 @@ from flask import Blueprint from libs.external_api import ExternalApi +from .files import FileApi, FilePreviewApi, FileSupportTypeApi +from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi + bp = Blueprint("console", __name__, url_prefix="/console/api") api = ExternalApi(bp) +# File +api.add_resource(FileApi, "/files/upload") +api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview") +api.add_resource(FileSupportTypeApi, "/files/support-type") + +# Remote files +api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>") +api.add_resource(RemoteFileUploadApi, "/remote-files/upload") + # Import other controllers from . import admin, apikey, extension, feature, ping, setup, version @@ -43,7 +55,6 @@ from .datasets import ( datasets_document, datasets_segments, external, - file, hit_testing, website, ) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index e014964bf9..b612f7bd96 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -12,8 +12,7 @@ from models.dataset import Dataset from models.model import ApiToken, App from .
import api -from .setup import setup_required -from .wraps import account_initialization_required +from .wraps import account_initialization_required, setup_required api_key_fields = { "id": fields.String, diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index e7346bdf1d..c228743fa5 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,8 +1,7 @@ from flask_restful import Resource, reqparse from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.advanced_prompt_template_service import AdvancedPromptTemplateService diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index 51899da705..d433415894 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from libs.helper import uuid_value from libs.login import login_required from models.model import AppMode diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 1ea1c82679..fd05cbc19b 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -6,8 +6,11 @@ from werkzeug.exceptions import Forbidden from controllers.console import api from controllers.console.app.error import NoFileUploadedError from controllers.console.datasets.error import TooManyFilesError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from extensions.ext_redis import redis_client from fields.annotation_fields import ( annotation_fields, diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 1b46a3a7d3..36338cbd8a 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -6,8 +6,11 @@ from werkzeug.exceptions import BadRequest, Forbidden, abort from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from core.ops.ops_trace_manager import OpsTraceManager from fields.app_fields import ( app_detail_fields, diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index c1ef05a488..112446613f 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -18,8 +18,7 @@ from controllers.console.app.error import ( UnsupportedAudioTypeError, ) from 
controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index d3296d3dff..9896fcaab8 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -15,8 +15,7 @@ from controllers.console.app.error import ( ProviderQuotaExceededError, ) from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index b60a424d98..7b78f622b9 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -10,8 +10,7 @@ from werkzeug.exceptions import Forbidden, NotFound from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import ( diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 23b234dac9..d49f433ba1 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -4,8 +4,7 @@ from sqlalchemy.orm import Session from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from fields.conversation_variable_fields import paginated_conversation_variable_fields from libs.login import login_required diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 7108759b0b..9c3cbe4e3e 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -10,8 +10,7 @@ from controllers.console.app.error import ( ProviderNotInitializeError, ProviderQuotaExceededError, ) -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.llm_generator.llm_generator import LLMGenerator from 
core.model_runtime.errors.invoke import InvokeError diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index fe06201982..b7a4c31a15 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -14,8 +14,11 @@ from controllers.console.app.error import ( ) from controllers.console.app.wraps import get_app_model from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index f5068a4cd8..8ba195f5a5 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -6,8 +6,7 @@ from flask_restful import Resource from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.agent.entities import AgentToolEntity from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 374bd2b815..47b58396a1 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse from controllers.console import api from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.ops_service import OpsService diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 115a832da9..2f5645852f 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -7,8 +7,7 @@ from werkzeug.exceptions import Forbidden, NotFound from constants.languages import supported_language from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from fields.app_fields import app_site_fields from libs.login import login_required diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 3ef442812d..db5e282409 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -8,8 +8,7 @@ from flask_restful import Resource, reqparse from 
controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from libs.helper import DatetimeString from libs.login import login_required diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 1ffdceb2c8..75354218c4 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -9,8 +9,7 @@ import services from controllers.console import api from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from factories import variable_factory diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 629b7a8bf4..2940556f84 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -3,8 +3,7 @@ from flask_restful.inputs import int_range from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from libs.login import login_required from models import App diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 5824ead9c3..08ab61bbb9 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -3,8 +3,7 @@ from flask_restful.inputs import int_range from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from fields.workflow_run_fields import ( advanced_chat_workflow_run_pagination_fields, workflow_run_detail_fields, diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index f46af0f1ca..6c7c73707b 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -8,8 +8,7 @@ from flask_restful import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from libs.helper import DatetimeString from libs.login import login_required diff --git a/api/controllers/console/auth/data_source_bearer_auth.py 
b/api/controllers/console/auth/data_source_bearer_auth.py index 50db6eebc1..465c44e9b6 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -7,8 +7,7 @@ from controllers.console.auth.error import ApiKeyAuthFailedError from libs.login import login_required from services.auth.api_key_auth_service import ApiKeyAuthService -from ..setup import setup_required -from ..wraps import account_initialization_required +from ..wraps import account_initialization_required, setup_required class ApiKeyAuthDataSource(Resource): diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index fd31e5ccc3..3c3f45260a 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -11,8 +11,7 @@ from controllers.console import api from libs.login import login_required from libs.oauth_data_source import NotionOAuth -from ..setup import setup_required -from ..wraps import account_initialization_required +from ..wraps import account_initialization_required, setup_required def get_oauth_providers(): diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 3c2de4612f..0cc115d0ee 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -15,7 +15,7 @@ from controllers.console.auth.error import ( PasswordMismatchError, ) from controllers.console.error import EmailSendIpLimitError, NotAllowedRegister -from controllers.console.setup import setup_required +from controllers.console.wraps import setup_required from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import email, extract_remote_ip diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 6c795f95b6..e2e8f84920 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -20,7 +20,7 @@ from controllers.console.error import ( NotAllowedCreateWorkspace, NotAllowedRegister, ) -from controllers.console.setup import setup_required +from controllers.console.wraps import setup_required from events.tenant_event import tenant_was_created from libs.helper import email, extract_remote_ip from libs.password import valid_password diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 9a1d914869..4b0c82ae6c 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -2,8 +2,7 @@ from flask_login import current_user from flask_restful import Resource, reqparse from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, only_edition_cloud +from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from libs.login import login_required from services.billing_service import BillingService diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index f024e3799c..06fb3a0a31 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -9,8 +9,7 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from controllers.console import api -from controllers.console.setup 
import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.indexing_runner import IndexingRunner from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.notion_extractor import NotionExtractor diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 4f4d186edd..07ef0ce3e5 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -10,8 +10,7 @@ from controllers.console import api from controllers.console.apikey import api_key_fields, api_key_list from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner from core.model_runtime.entities.model_entities import ModelType diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 31b4f7b741..521805a651 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -23,8 +23,11 @@ from controllers.console.datasets.error import ( InvalidActionError, InvalidMetadataError, ) -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from core.errors.error import ( LLMBadRequestError, ModelCurrentlyNotSupportError, diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 08ea414288..5d8d664e41 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -11,11 +11,11 @@ import services from controllers.console import api from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError -from controllers.console.setup import setup_required from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_knowledge_limit_check, cloud_edition_billing_resource_check, + setup_required, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 2dc054cfbd..bc6e3687c1 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -6,8 +6,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from controllers.console import api from controllers.console.datasets.error import DatasetNameDuplicateError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from 
controllers.console.wraps import account_initialization_required, setup_required from fields.dataset_fields import dataset_detail_fields from libs.login import login_required from services.dataset_service import DatasetService diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 5c9bcef84c..495f511275 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -2,8 +2,7 @@ from flask_restful import Resource from controllers.console import api from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index e80ce17c68..9127c8af45 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse from controllers.console import api from controllers.console.datasets.error import WebsiteCrawlError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.website_service import WebsiteService diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 5d6a8bf152..4ac0aa497e 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -3,8 +3,7 @@ from flask_restful import Resource, marshal_with, reqparse from constants import HIDDEN_VALUE from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields from libs.login import login_required from models.api_based_extension import APIBasedExtension diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index f0482f749d..70ab4ff865 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -5,8 +5,7 @@ from libs.login import login_required from services.feature_service import FeatureService from . 
import api -from .setup import setup_required -from .wraps import account_initialization_required, cloud_utm_record +from .wraps import account_initialization_required, cloud_utm_record, setup_required class FeatureApi(Resource): diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/files/__init__.py similarity index 57% rename from api/controllers/console/datasets/file.py rename to api/controllers/console/files/__init__.py index 17d2879875..69ee7eaabd 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/files/__init__.py @@ -1,25 +1,26 @@ -import urllib.parse - from flask import request from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restful import Resource, marshal_with import services from configs import dify_config from constants import DOCUMENT_EXTENSIONS -from controllers.console import api -from controllers.console.datasets.error import ( +from controllers.common.errors import FilenameNotExistsError +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) +from fields.file_fields import file_fields, upload_config_fields +from libs.login import login_required +from services.file_service import FileService + +from .errors import ( FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError, ) -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.helper import ssrf_proxy -from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields -from libs.login import login_required -from services.file_service import FileService PREVIEW_WORDS_LIMIT = 3000 @@ -44,21 +45,29 @@ class FileApi(Resource): @marshal_with(file_fields) @cloud_edition_billing_resource_check("documents") def post(self): - # get file from request file = request.files["file"] + source = request.form.get("source") - parser = reqparse.RequestParser() - parser.add_argument("source", type=str, required=False, location="args") - source = parser.parse_args().get("source") - - # check file if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() + + if not file.filename: + raise FilenameNotExistsError + + if source not in ("datasets", None): + source = None + try: - upload_file = FileService.upload_file(file=file, user=current_user, source=source) + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + source=source, + ) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: @@ -83,23 +92,3 @@ class FileSupportTypeApi(Resource): @account_initialization_required def get(self): return {"allowed_extensions": DOCUMENT_EXTENSIONS} - - -class RemoteFileInfoApi(Resource): - @marshal_with(remote_file_info_fields) - def get(self, url): - decoded_url = urllib.parse.unquote(url) - try: - response = ssrf_proxy.head(decoded_url) - return { - "file_type": response.headers.get("Content-Type", "application/octet-stream"), - "file_length": int(response.headers.get("Content-Length", 0)), - } - except Exception as e: - return {"error": str(e)}, 400 - - -api.add_resource(FileApi, "/files/upload") -api.add_resource(FilePreviewApi, 
"/files//preview") -api.add_resource(FileSupportTypeApi, "/files/support-type") -api.add_resource(RemoteFileInfoApi, "/remote-files/") diff --git a/api/controllers/console/files/errors.py b/api/controllers/console/files/errors.py new file mode 100644 index 0000000000..1654ef2cf4 --- /dev/null +++ b/api/controllers/console/files/errors.py @@ -0,0 +1,25 @@ +from libs.exception import BaseHTTPException + + +class FileTooLargeError(BaseHTTPException): + error_code = "file_too_large" + description = "File size exceeded. {message}" + code = 413 + + +class UnsupportedFileTypeError(BaseHTTPException): + error_code = "unsupported_file_type" + description = "File type not allowed." + code = 415 + + +class TooManyFilesError(BaseHTTPException): + error_code = "too_many_files" + description = "Only one file is allowed." + code = 400 + + +class NoFileUploadedError(BaseHTTPException): + error_code = "no_file_uploaded" + description = "Please upload your file." + code = 400 diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py new file mode 100644 index 0000000000..42d6e25416 --- /dev/null +++ b/api/controllers/console/remote_files.py @@ -0,0 +1,71 @@ +import urllib.parse +from typing import cast + +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse + +from controllers.common import helpers +from core.file import helpers as file_helpers +from core.helper import ssrf_proxy +from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields +from models.account import Account +from services.file_service import FileService + + +class RemoteFileInfoApi(Resource): + @marshal_with(remote_file_info_fields) + def get(self, url): + decoded_url = urllib.parse.unquote(url) + try: + response = ssrf_proxy.head(decoded_url) + return { + "file_type": response.headers.get("Content-Type", "application/octet-stream"), + "file_length": int(response.headers.get("Content-Length", 0)), + } + except Exception as e: + return {"error": str(e)}, 400 + + +class RemoteFileUploadApi(Resource): + @marshal_with(file_fields_with_signed_url) + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("url", type=str, required=True, help="URL is required") + args = parser.parse_args() + + url = args["url"] + + response = ssrf_proxy.head(url) + response.raise_for_status() + + file_info = helpers.guess_file_info_from_response(response) + + if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): + return {"error": "File size exceeded"}, 400 + + response = ssrf_proxy.get(url) + response.raise_for_status() + content = response.content + + try: + user = cast(Account, current_user) + upload_file = FileService.upload_file( + filename=file_info.filename, + content=content, + mimetype=file_info.mimetype, + user=user, + source_url=url, + ) + except Exception as e: + return {"error": str(e)}, 400 + + return { + "id": upload_file.id, + "name": upload_file.name, + "size": upload_file.size, + "extension": upload_file.extension, + "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + "mime_type": upload_file.mime_type, + "created_by": upload_file.created_by, + "created_at": upload_file.created_at, + }, 201 diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index d229bb2a19..e1f19a87a3 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,5 +1,3 @@ -from functools import wraps - from flask import request 
from flask_restful import Resource, reqparse @@ -10,7 +8,7 @@ from models.model import DifySetup, db from services.account_service import RegisterService, TenantService from . import api -from .error import AlreadySetupError, NotInitValidateError, NotSetupError +from .error import AlreadySetupError, NotInitValidateError from .init_validate import get_init_validate_status from .wraps import only_edition_self_hosted @@ -52,21 +50,6 @@ class SetupApi(Resource): return {"result": "success"}, 201 -def setup_required(view): - @wraps(view) - def decorated(*args, **kwargs): - # check setup - if not get_init_validate_status(): - raise NotInitValidateError() - - elif not get_setup_status(): - raise NotSetupError() - - return view(*args, **kwargs) - - return decorated - - def get_setup_status(): if dify_config.EDITION == "SELF_HOSTED": return db.session.query(DifySetup).first() diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index de30547e93..ccd3293a62 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -4,8 +4,7 @@ from flask_restful import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from fields.tag_fields import tag_fields from libs.login import login_required from models.model import Tag diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index deda1a0d02..7dea8e554e 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -3,6 +3,7 @@ import logging import requests from flask_restful import Resource, reqparse +from packaging import version from configs import dify_config @@ -47,43 +48,15 @@ class VersionApi(Resource): def _has_new_version(*, latest_version: str, current_version: str) -> bool: - def parse_version(version: str) -> tuple: - # Split version into parts and pre-release suffix if any - parts = version.split("-") - version_parts = parts[0].split(".") - pre_release = parts[1] if len(parts) > 1 else None + try: + latest = version.parse(latest_version) + current = version.parse(current_version) - # Validate version format - if len(version_parts) != 3: - raise ValueError(f"Invalid version format: {version}") - - try: - # Convert version parts to integers - major, minor, patch = map(int, version_parts) - return (major, minor, patch, pre_release) - except ValueError: - raise ValueError(f"Invalid version format: {version}") - - latest = parse_version(latest_version) - current = parse_version(current_version) - - # Compare major, minor, and patch versions - for latest_part, current_part in zip(latest[:3], current[:3]): - if latest_part > current_part: - return True - elif latest_part < current_part: - return False - - # If versions are equal, check pre-release suffixes - if latest[3] is None and current[3] is not None: - return True - elif latest[3] is not None and current[3] is None: + # Compare versions + return latest > current + except version.InvalidVersion: + logging.warning(f"Invalid version format: latest={latest_version}, current={current_version}") return False - elif latest[3] is not None and current[3] is not None: - # Simple string comparison for pre-release versions - return latest[3] > current[3] - - return False api.add_resource(VersionApi, "/version") diff 
--git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 97f5625726..aabc417759 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -8,14 +8,13 @@ from flask_restful import Resource, fields, marshal_with, reqparse from configs import dify_config from constants.languages import supported_language from controllers.console import api -from controllers.console.setup import setup_required from controllers.console.workspace.error import ( AccountAlreadyInitedError, CurrentPasswordIncorrectError, InvalidInvitationCodeError, RepeatPasswordNotMatchError, ) -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from fields.member_fields import account_fields from libs.helper import TimestampField, timezone diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 9c9c76c9f4..0aa66bfd6e 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -3,8 +3,7 @@ from flask_restful import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from libs.login import login_required from services.plugin.endpoint_service import EndpointService diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 9d2697f11d..114905cf1d 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_user, login_required diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 3e87bebf59..8f694c65e0 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -4,8 +4,11 @@ from flask_restful import Resource, abort, marshal_with, reqparse import services from configs import dify_config from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from extensions.ext_database import db from fields.member_fields import account_with_role_list_fields from libs.login import login_required diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 
b9f13e3ce4..b9612c0f9d 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -6,8 +6,7 @@ from flask_restful import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 7bbedc8828..daa3455e2f 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -5,8 +5,7 @@ from flask_restful import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index a25d906528..1cb83c136f 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -7,9 +7,8 @@ from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api -from controllers.console.setup import setup_required from controllers.console.workspace import plugin_permission_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from libs.login import login_required from models.account import TenantPluginPermission diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 3959e59307..910b991de1 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -7,8 +7,7 @@ from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from libs.helper import alphanumeric, uuid_value from libs.login import login_required diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 96f866fca2..76d76f6b58 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -6,6 +6,7 @@ from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqpa from werkzeug.exceptions import Unauthorized import services +from controllers.common.errors import FilenameNotExistsError from controllers.console 
import api from controllers.console.admin import admin_required from controllers.console.datasets.error import ( @@ -15,8 +16,11 @@ from controllers.console.datasets.error import ( UnsupportedFileTypeError, ) from controllers.console.error import AccountNotLinkTenantError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from extensions.ext_database import db from libs.helper import TimestampField from libs.login import login_required @@ -193,12 +197,20 @@ class WebappLogoWorkspaceApi(Resource): if len(request.files) > 1: raise TooManyFilesError() + if not file.filename: + raise FilenameNotExistsError + extension = file.filename.split(".")[-1] if extension.lower() not in {"svg", "png"}: raise UnsupportedFileTypeError() try: - upload_file = FileService.upload_file(file=file, user=current_user) + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + ) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 46223d104f..291e2500aa 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -1,4 +1,5 @@ import json +import os from functools import wraps from flask import abort, request @@ -6,9 +7,13 @@ from flask_login import current_user from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError +from extensions.ext_database import db +from models.model import DifySetup from services.feature_service import FeatureService from services.operation_service import OperationService +from .error import NotInitValidateError, NotSetupError + def account_initialization_required(view): @wraps(view) @@ -124,3 +129,21 @@ def cloud_utm_record(view): return view(*args, **kwargs) return decorated + + +def setup_required(view): + @wraps(view) + def decorated(*args, **kwargs): + # check setup + if ( + dify_config.EDITION == "SELF_HOSTED" + and os.environ.get("INIT_PASSWORD") + and not db.session.query(DifySetup).first() + ): + raise NotInitValidateError() + elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first(): + raise NotSetupError() + + return view(*args, **kwargs) + + return decorated diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 7a980d6e39..e507c084a9 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,6 +1,6 @@ from flask_restful import Resource -from controllers.console.setup import setup_required +from controllers.console.wraps import setup_required from controllers.inner_api import api from controllers.inner_api.plugin.wraps import get_tenant, plugin_data from controllers.inner_api.wraps import plugin_inner_api_only diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 3feb3452aa..64cb5e54ff 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -1,6 +1,6 @@ from flask_restful import Resource, reqparse -from controllers.console.setup import setup_required +from 
diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 3feb3452aa..64cb5e54ff 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -1,6 +1,6 @@ from flask_restful import Resource, reqparse -from controllers.console.setup import setup_required +from controllers.console.wraps import setup_required from controllers.inner_api import api from controllers.inner_api.wraps import enterprise_inner_api_only from events.tenant_event import tenant_was_created diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index e0a772eb31..b0126058de 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -2,6 +2,7 @@ from flask import request from flask_restful import Resource, marshal_with import services +from controllers.common.errors import FilenameNotExistsError from controllers.service_api import api from controllers.service_api.app.error import ( FileTooLargeError, @@ -31,8 +32,17 @@ class FileApi(Resource): if len(request.files) > 1: raise TooManyFilesError() + if not file.filename: + raise FilenameNotExistsError + try: - upload_file = FileService.upload_file(file, end_user) + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=end_user, + source="datasets", + ) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 0a0a38c4c6..5c3fc7b241 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -6,6 +6,7 @@ from sqlalchemy import desc from werkzeug.exceptions import NotFound import services.dataset_service +from controllers.common.errors import FilenameNotExistsError from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.dataset.error import ( @@ -55,7 +56,12 @@ class DocumentAddByTextApi(DatasetApiResource): if not dataset.indexing_technique and not args["indexing_technique"]: raise ValueError("indexing_technique is required.") - upload_file = FileService.upload_text(args.get("text"), args.get("name")) + text = args.get("text") + name = args.get("name") + if text is None or name is None: + raise ValueError("Both 'text' and 'name' must be non-null values.") + + upload_file = FileService.upload_text(text=str(text), text_name=str(name)) data_source = { "type": "upload_file", "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, @@ -104,7 +110,11 @@ class DocumentUpdateByTextApi(DatasetApiResource): raise ValueError("Dataset is not exist.") if args["text"]: - upload_file = FileService.upload_text(args.get("text"), args.get("name")) + text = args.get("text") + name = args.get("name") + if text is None or name is None: + raise ValueError("Both 'text' and 'name' must be non-null values.") + upload_file = FileService.upload_text(text=str(text), text_name=str(name)) data_source = { "type": "upload_file", "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, @@ -163,7 +173,16 @@ class DocumentAddByFileApi(DatasetApiResource): if len(request.files) > 1: raise TooManyFilesError() - upload_file = FileService.upload_file(file, current_user) + if not file.filename: + raise FilenameNotExistsError + + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + source="datasets", + ) data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids":
[upload_file.id]}}} args["data_source"] = data_source # validate args @@ -212,7 +231,16 @@ class DocumentUpdateByFileApi(DatasetApiResource): if len(request.files) > 1: raise TooManyFilesError() - upload_file = FileService.upload_file(file, current_user) + if not file.filename: + raise FilenameNotExistsError + + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + source="datasets", + ) data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} args["data_source"] = data_source # validate args @@ -331,10 +359,26 @@ class DocumentIndexingStatusApi(DatasetApiResource): return data -api.add_resource(DocumentAddByTextApi, "/datasets/<uuid:dataset_id>/document/create_by_text") -api.add_resource(DocumentAddByFileApi, "/datasets/<uuid:dataset_id>/document/create_by_file") -api.add_resource(DocumentUpdateByTextApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text") -api.add_resource(DocumentUpdateByFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file") +api.add_resource( + DocumentAddByTextApi, + "/datasets/<uuid:dataset_id>/document/create_by_text", + "/datasets/<uuid:dataset_id>/document/create-by-text", +) +api.add_resource( + DocumentAddByFileApi, + "/datasets/<uuid:dataset_id>/document/create_by_file", + "/datasets/<uuid:dataset_id>/document/create-by-file", +) +api.add_resource( + DocumentUpdateByTextApi, + "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text", + "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text", +) +api.add_resource( + DocumentUpdateByFileApi, + "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file", + "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file", +) api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>") api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents") api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status") diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index 9c9a4302c9..465f71bf03 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -14,4 +14,4 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): return self.perform_hit_testing(dataset, args) -api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing") +api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve") diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 630b9468a7..50a04a6254 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -2,8 +2,17 @@ from flask import Blueprint from libs.external_api import ExternalApi +from .files import FileApi +from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi + bp = Blueprint("web", __name__, url_prefix="/api") api = ExternalApi(bp) +# Files +api.add_resource(FileApi, "/files/upload") -from . import app, audio, completion, conversation, feature, file, message, passport, saved_message, site, workflow +# Remote files +api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>") +api.add_resource(RemoteFileUploadApi, "/remote-files/upload") + +from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow
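The web blueprint now registers its file endpoints explicitly at import time. A rough client-side sketch of exercising the two new upload routes; the base URL, bearer token, and file names are assumptions for illustration, and token acquisition via the passport endpoint is not shown:

```python
import requests  # third-party HTTP client, assumed available

BASE = "http://localhost/api"  # the web blueprint is mounted under /api
HEADERS = {"Authorization": "Bearer <web-app-token>"}  # placeholder token

# Multipart upload to FileApi ("/files/upload"). The handler rejects a missing
# filename, more than one file, and any "source" value other than "datasets".
with open("report.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE}/files/upload",
        headers=HEADERS,
        files={"file": ("report.pdf", f, "application/pdf")},
    )
    print(resp.status_code)  # expect 201 with the marshalled file fields

# Upload-by-URL to RemoteFileUploadApi ("/remote-files/upload"): the server
# HEADs the URL, enforces the size limit, then GETs and stores the content.
resp = requests.post(
    f"{BASE}/remote-files/upload",
    headers=HEADERS,
    json={"url": "https://example.com/sample.png"},
)
print(resp.json())  # includes a signed download URL on success
```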
diff --git a/api/controllers/web/file.py b/api/controllers/web/file.py deleted file mode 100644 index 6eeaa0e3f0..0000000000 --- a/api/controllers/web/file.py +++ /dev/null @@ -1,56 +0,0 @@ -import urllib.parse - -from flask import request -from flask_restful import marshal_with, reqparse - -import services -from controllers.web import api -from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError -from controllers.web.wraps import WebApiResource -from core.helper import ssrf_proxy -from fields.file_fields import file_fields, remote_file_info_fields -from services.file_service import FileService - - -class FileApi(WebApiResource): - @marshal_with(file_fields) - def post(self, app_model, end_user): - # get file from request - file = request.files["file"] - - parser = reqparse.RequestParser() - parser.add_argument("source", type=str, required=False, location="args") - source = parser.parse_args().get("source") - - # check file - if "file" not in request.files: - raise NoFileUploadedError() - - if len(request.files) > 1: - raise TooManyFilesError() - try: - upload_file = FileService.upload_file(file=file, user=end_user, source=source) - except services.errors.file.FileTooLargeError as file_too_large_error: - raise FileTooLargeError(file_too_large_error.description) - except services.errors.file.UnsupportedFileTypeError: - raise UnsupportedFileTypeError() - - return upload_file, 201 - - -class RemoteFileInfoApi(WebApiResource): - @marshal_with(remote_file_info_fields) - def get(self, url): - decoded_url = urllib.parse.unquote(url) - try: - response = ssrf_proxy.head(decoded_url) - return { - "file_type": response.headers.get("Content-Type", "application/octet-stream"), - "file_length": int(response.headers.get("Content-Length", -1)), - } - except Exception as e: - return {"error": str(e)}, 400 - - -api.add_resource(FileApi, "/files/upload") -api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>") diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py new file mode 100644 index 0000000000..a282fc63a8 --- /dev/null +++ b/api/controllers/web/files.py @@ -0,0 +1,43 @@ +from flask import request +from flask_restful import marshal_with + +import services +from controllers.common.errors import FilenameNotExistsError +from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError +from controllers.web.wraps import WebApiResource +from fields.file_fields import file_fields +from services.file_service import FileService + - -class FileApi(WebApiResource): + @marshal_with(file_fields) + def post(self, app_model, end_user): + file = request.files["file"] + source = request.form.get("source") + + if "file" not in request.files: + raise NoFileUploadedError() + + if len(request.files) > 1: + raise TooManyFilesError() + + if not file.filename: + raise FilenameNotExistsError + + if source not in ("datasets", None): + source = None + + try: + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=end_user, + source=source, + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + return upload_file, 201 diff --git 
a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py new file mode 100644 index 0000000000..cb529340af --- /dev/null +++ b/api/controllers/web/remote_files.py @@ -0,0 +1,69 @@ +import urllib.parse + +from flask_login import current_user +from flask_restful import marshal_with, reqparse + +from controllers.common import helpers +from controllers.web.wraps import WebApiResource +from core.file import helpers as file_helpers +from core.helper import ssrf_proxy +from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields +from services.file_service import FileService + + +class RemoteFileInfoApi(WebApiResource): + @marshal_with(remote_file_info_fields) + def get(self, url): + decoded_url = urllib.parse.unquote(url) + try: + response = ssrf_proxy.head(decoded_url) + return { + "file_type": response.headers.get("Content-Type", "application/octet-stream"), + "file_length": int(response.headers.get("Content-Length", -1)), + } + except Exception as e: + return {"error": str(e)}, 400 + + +class RemoteFileUploadApi(WebApiResource): + @marshal_with(file_fields_with_signed_url) + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("url", type=str, required=True, help="URL is required") + args = parser.parse_args() + + url = args["url"] + + response = ssrf_proxy.head(url) + response.raise_for_status() + + file_info = helpers.guess_file_info_from_response(response) + + if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): + return {"error": "File size exceeded"}, 400 + + response = ssrf_proxy.get(url) + response.raise_for_status() + content = response.content + + try: + upload_file = FileService.upload_file( + filename=file_info.filename, + content=content, + mimetype=file_info.mimetype, + user=current_user, + source_url=url, + ) + except Exception as e: + return {"error": str(e)}, 400 + + return { + "id": upload_file.id, + "name": upload_file.name, + "size": upload_file.size, + "extension": upload_file.extension, + "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + "mime_type": upload_file.mime_type, + "created_by": upload_file.created_by, + "created_at": upload_file.created_at, + }, 201 diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 8df26172b7..fb9fe8f210 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -17,6 +17,7 @@ from core.errors.error import ProviderTokenNotInitError from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType +from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting @@ -597,26 +598,9 @@ class IndexingRunner: rules = DatasetProcessRule.AUTOMATIC_RULES else: rules = json.loads(processing_rule.rules) if processing_rule.rules else {} + document_text = CleanProcessor.clean(text, rules) - if "pre_processing_rules" in rules: - pre_processing_rules = rules["pre_processing_rules"] - for pre_processing_rule in pre_processing_rules: - if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: - # Remove extra spaces - pattern = r"\n{3,}" - text = re.sub(pattern, "\n\n", text) - pattern = 
r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}" - text = re.sub(pattern, " ", text) - elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: - # Remove email - pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" - text = re.sub(pattern, "", text) - - # Remove URL - pattern = r"https?://[^\s]+" - text = re.sub(pattern, "", text) - - return text + return document_text @staticmethod def format_split_text(text): diff --git a/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml b/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml deleted file mode 100644 index aca9456313..0000000000 --- a/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml +++ /dev/null @@ -1,9 +0,0 @@ -- claude-3-5-sonnet-20241022 -- claude-3-5-sonnet-20240620 -- claude-3-haiku-20240307 -- claude-3-opus-20240229 -- claude-3-sonnet-20240229 -- claude-2.1 -- claude-instant-1.2 -- claude-2 -- claude-instant-1 diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml deleted file mode 100644 index e20b8c4960..0000000000 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml +++ /dev/null @@ -1,39 +0,0 @@ -model: claude-3-5-sonnet-20241022 -label: - en_US: claude-3-5-sonnet-20241022 -model_type: llm -features: - - agent-thought - - vision - - tool-call - - stream-tool-call -model_properties: - mode: chat - context_size: 200000 -parameter_rules: - - name: temperature - use_template: temperature - - name: top_p - use_template: top_p - - name: top_k - label: - zh_Hans: 取样数量 - en_US: Top k - type: int - help: - zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 - en_US: Only sample from the top K options for each subsequent token. - required: false - - name: max_tokens - use_template: max_tokens - required: true - default: 8192 - min: 1 - max: 8192 - - name: response_format - use_template: response_format -pricing: - input: '3.00' - output: '15.00' - unit: '0.000001' - currency: USD diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml deleted file mode 100644 index 1ef5e83abc..0000000000 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml +++ /dev/null @@ -1,245 +0,0 @@ -provider: azure_openai -label: - en_US: Azure OpenAI Service Model -icon_small: - en_US: icon_s_en.svg -icon_large: - en_US: icon_l_en.png -background: "#E3F0FF" -help: - title: - en_US: Get your API key from Azure - zh_Hans: 从 Azure 获取 API Key - url: - en_US: https://azure.microsoft.com/en-us/products/ai-services/openai-service -supported_model_types: - - llm - - text-embedding - - speech2text - - tts -configurate_methods: - - customizable-model -model_credential_schema: - model: - label: - en_US: Deployment Name - zh_Hans: 部署名称 - placeholder: - en_US: Enter your Deployment Name here, matching the Azure deployment name. 
- zh_Hans: 在此输入您的部署名称,与 Azure 部署名称匹配。 - credential_form_schemas: - - variable: openai_api_base - label: - en_US: API Endpoint URL - zh_Hans: API 域名 - type: text-input - required: true - placeholder: - zh_Hans: '在此输入您的 API 域名,如:https://example.com/xxx' - en_US: 'Enter your API Endpoint, eg: https://example.com/xxx' - - variable: openai_api_key - label: - en_US: API Key - zh_Hans: API Key - type: secret-input - required: true - placeholder: - zh_Hans: 在此输入您的 API Key - en_US: Enter your API key here - - variable: openai_api_version - label: - zh_Hans: API 版本 - en_US: API Version - type: select - required: true - options: - - label: - en_US: 2024-10-01-preview - value: 2024-10-01-preview - - label: - en_US: 2024-09-01-preview - value: 2024-09-01-preview - - label: - en_US: 2024-08-01-preview - value: 2024-08-01-preview - - label: - en_US: 2024-07-01-preview - value: 2024-07-01-preview - - label: - en_US: 2024-05-01-preview - value: 2024-05-01-preview - - label: - en_US: 2024-04-01-preview - value: 2024-04-01-preview - - label: - en_US: 2024-03-01-preview - value: 2024-03-01-preview - - label: - en_US: 2024-02-15-preview - value: 2024-02-15-preview - - label: - en_US: 2023-12-01-preview - value: 2023-12-01-preview - - label: - en_US: '2024-02-01' - value: '2024-02-01' - - label: - en_US: '2024-06-01' - value: '2024-06-01' - placeholder: - zh_Hans: 在此选择您的 API 版本 - en_US: Select your API Version here - - variable: base_model_name - label: - en_US: Base Model - zh_Hans: 基础模型 - type: select - required: true - options: - - label: - en_US: gpt-35-turbo - value: gpt-35-turbo - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-35-turbo-0125 - value: gpt-35-turbo-0125 - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-35-turbo-16k - value: gpt-35-turbo-16k - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-4 - value: gpt-4 - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-4-32k - value: gpt-4-32k - show_on: - - variable: __model_type - value: llm - - label: - en_US: o1-mini - value: o1-mini - show_on: - - variable: __model_type - value: llm - - label: - en_US: o1-preview - value: o1-preview - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-4o-mini - value: gpt-4o-mini - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-4o-mini-2024-07-18 - value: gpt-4o-mini-2024-07-18 - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-4o - value: gpt-4o - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-4o-2024-05-13 - value: gpt-4o-2024-05-13 - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-4o-2024-08-06 - value: gpt-4o-2024-08-06 - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-4-turbo - value: gpt-4-turbo - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-4-turbo-2024-04-09 - value: gpt-4-turbo-2024-04-09 - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-4-0125-preview - value: gpt-4-0125-preview - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-4-1106-preview - value: gpt-4-1106-preview - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-4-vision-preview - value: gpt-4-vision-preview - show_on: - - variable: __model_type - value: llm - - label: - en_US: gpt-35-turbo-instruct - value: gpt-35-turbo-instruct - show_on: - - variable: __model_type - value: llm - - label: - en_US: 
text-embedding-ada-002 - value: text-embedding-ada-002 - show_on: - - variable: __model_type - value: text-embedding - - label: - en_US: text-embedding-3-small - value: text-embedding-3-small - show_on: - - variable: __model_type - value: text-embedding - - label: - en_US: text-embedding-3-large - value: text-embedding-3-large - show_on: - - variable: __model_type - value: text-embedding - - label: - en_US: whisper-1 - value: whisper-1 - show_on: - - variable: __model_type - value: speech2text - - label: - en_US: tts-1 - value: tts-1 - show_on: - - variable: __model_type - value: tts - - label: - en_US: tts-1-hd - value: tts-1-hd - show_on: - - variable: __model_type - value: tts - placeholder: - zh_Hans: 在此输入您的模型版本 - en_US: Enter your model version diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py deleted file mode 100644 index 1cd4823e13..0000000000 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ /dev/null @@ -1,764 +0,0 @@ -import copy -import json -import logging -from collections.abc import Generator, Sequence -from typing import Optional, Union, cast - -import tiktoken -from openai import AzureOpenAI, Stream -from openai.types import Completion -from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall -from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall - -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, - PromptMessageFunction, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, -) -from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI -from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS -from core.model_runtime.utils import helper - -logger = logging.getLogger(__name__) - - -class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): - def _invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, - ) -> Union[LLMResult, Generator]: - base_model_name = self._get_base_model_name(credentials) - ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) - - if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: - # chat model - return self._chat_generate( - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - else: - # text completion model - return self._generate( - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - stop=stop, - stream=stream, - user=user, - ) - - def get_num_tokens( - self, - model: str, - credentials: dict, - 
prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - ) -> int: - base_model_name = self._get_base_model_name(credentials) - model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) - if not model_entity: - raise ValueError(f"Base Model Name {base_model_name} is invalid") - model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE) - - if model_mode == LLMMode.CHAT.value: - # chat model - return self._num_tokens_from_messages(credentials, prompt_messages, tools) - else: - # text completion model, do not support tool calling - content = prompt_messages[0].content - assert isinstance(content, str) - return self._num_tokens_from_string(credentials, content) - - def validate_credentials(self, model: str, credentials: dict) -> None: - if "openai_api_base" not in credentials: - raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required") - - if "openai_api_key" not in credentials: - raise CredentialsValidateFailedError("Azure OpenAI API key is required") - - if "base_model_name" not in credentials: - raise CredentialsValidateFailedError("Base Model Name is required") - - base_model_name = self._get_base_model_name(credentials) - ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) - - if not ai_model_entity: - raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') - - try: - client = AzureOpenAI(**self._to_credential_kwargs(credentials)) - - if model.startswith("o1"): - client.chat.completions.create( - messages=[{"role": "user", "content": "ping"}], - model=model, - temperature=1, - max_completion_tokens=20, - stream=False, - ) - elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: - # chat model - client.chat.completions.create( - messages=[{"role": "user", "content": "ping"}], - model=model, - temperature=0, - max_tokens=20, - stream=False, - ) - else: - # text completion model - client.completions.create( - prompt="ping", - model=model, - temperature=0, - max_tokens=20, - stream=False, - ) - except Exception as ex: - raise CredentialsValidateFailedError(str(ex)) - - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - base_model_name = self._get_base_model_name(credentials) - ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) - return ai_model_entity.entity if ai_model_entity else None - - def _generate( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, - ) -> Union[LLMResult, Generator]: - client = AzureOpenAI(**self._to_credential_kwargs(credentials)) - - extra_model_kwargs = {} - - if stop: - extra_model_kwargs["stop"] = stop - - if user: - extra_model_kwargs["user"] = user - - # text completion model - response = client.completions.create( - prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs - ) - - if stream: - return self._handle_generate_stream_response(model, credentials, response, prompt_messages) - - return self._handle_generate_response(model, credentials, response, prompt_messages) - - def _handle_generate_response( - self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage] - ): - assistant_text = response.choices[0].text - - # 
transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage(content=assistant_text) - - # calculate num tokens - if response.usage: - # transform usage - prompt_tokens = response.usage.prompt_tokens - completion_tokens = response.usage.completion_tokens - else: - # calculate num tokens - content = prompt_messages[0].content - assert isinstance(content, str) - prompt_tokens = self._num_tokens_from_string(credentials, content) - completion_tokens = self._num_tokens_from_string(credentials, assistant_text) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - # transform response - result = LLMResult( - model=response.model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage, - system_fingerprint=response.system_fingerprint, - ) - - return result - - def _handle_generate_stream_response( - self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage] - ) -> Generator: - full_text = "" - for chunk in response: - if len(chunk.choices) == 0: - continue - - delta = chunk.choices[0] - - if delta.finish_reason is None and (delta.text is None or delta.text == ""): - continue - - # transform assistant message to prompt message - text = delta.text or "" - assistant_prompt_message = AssistantPromptMessage(content=text) - - full_text += text - - if delta.finish_reason is not None: - # calculate num tokens - if chunk.usage: - # transform usage - prompt_tokens = chunk.usage.prompt_tokens - completion_tokens = chunk.usage.completion_tokens - else: - # calculate num tokens - content = prompt_messages[0].content - assert isinstance(content, str) - prompt_tokens = self._num_tokens_from_string(credentials, content) - completion_tokens = self._num_tokens_from_string(credentials, full_text) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - yield LLMResultChunk( - model=chunk.model, - prompt_messages=prompt_messages, - system_fingerprint=chunk.system_fingerprint, - delta=LLMResultChunkDelta( - index=delta.index, - message=assistant_prompt_message, - finish_reason=delta.finish_reason, - usage=usage, - ), - ) - else: - yield LLMResultChunk( - model=chunk.model, - prompt_messages=prompt_messages, - system_fingerprint=chunk.system_fingerprint, - delta=LLMResultChunkDelta( - index=delta.index, - message=assistant_prompt_message, - ), - ) - - def _chat_generate( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, - ) -> Union[LLMResult, Generator]: - client = AzureOpenAI(**self._to_credential_kwargs(credentials)) - - response_format = model_parameters.get("response_format") - if response_format: - if response_format == "json_schema": - json_schema = model_parameters.get("json_schema") - if not json_schema: - raise ValueError("Must define JSON Schema when the response format is json_schema") - try: - schema = json.loads(json_schema) - except: - raise ValueError(f"not correct json_schema format: {json_schema}") - model_parameters.pop("json_schema") - model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema} - else: - model_parameters["response_format"] = {"type": response_format} - - extra_model_kwargs = {} - - if tools: - extra_model_kwargs["tools"] = 
[helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] - - if stop: - extra_model_kwargs["stop"] = stop - - if user: - extra_model_kwargs["user"] = user - - # clear illegal prompt messages - prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) - - block_as_stream = False - if model.startswith("o1"): - if stream: - block_as_stream = True - stream = False - - if "stream_options" in extra_model_kwargs: - del extra_model_kwargs["stream_options"] - - if "stop" in extra_model_kwargs: - del extra_model_kwargs["stop"] - - # chat model - response = client.chat.completions.create( - messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], - model=model, - stream=stream, - **model_parameters, - **extra_model_kwargs, - ) - - if stream: - return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) - - block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) - - if block_as_stream: - return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop) - - return block_result - - def _handle_chat_block_as_stream_response( - self, - block_result: LLMResult, - prompt_messages: list[PromptMessage], - stop: Optional[list[str]] = None, - ) -> Generator[LLMResultChunk, None, None]: - """ - Handle llm chat response - - :param model: model name - :param credentials: credentials - :param response: response - :param prompt_messages: prompt messages - :param tools: tools for tool calling - :param stop: stop words - :return: llm response chunk generator - """ - text = block_result.message.content - text = cast(str, text) - - if stop: - text = self.enforce_stop_tokens(text, stop) - - yield LLMResultChunk( - model=block_result.model, - prompt_messages=prompt_messages, - system_fingerprint=block_result.system_fingerprint, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=text), - finish_reason="stop", - usage=block_result.usage, - ), - ) - - def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: - """ - Clear illegal prompt messages for OpenAI API - - :param model: model name - :param prompt_messages: prompt messages - :return: cleaned prompt messages - """ - checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"] - - if model in checklist: - # count how many user messages are there - user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)]) - if user_message_count > 1: - for prompt_message in prompt_messages: - if isinstance(prompt_message, UserPromptMessage): - if isinstance(prompt_message.content, list): - prompt_message.content = "\n".join( - [ - item.data - if item.type == PromptMessageContentType.TEXT - else "[IMAGE]" - if item.type == PromptMessageContentType.IMAGE - else "" - for item in prompt_message.content - ] - ) - - if model.startswith("o1"): - system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)]) - if system_message_count > 0: - new_prompt_messages = [] - for prompt_message in prompt_messages: - if isinstance(prompt_message, SystemPromptMessage): - prompt_message = UserPromptMessage( - content=prompt_message.content, - name=prompt_message.name, - ) - - new_prompt_messages.append(prompt_message) - prompt_messages = new_prompt_messages - - return prompt_messages - - def _handle_chat_generate_response( - self, - model: str, - credentials: dict, - response: ChatCompletion, - 
prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - ): - assistant_message = response.choices[0].message - assistant_message_tool_calls = assistant_message.tool_calls - - # extract tool calls from response - tool_calls = [] - self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls) - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) - - # calculate num tokens - if response.usage: - # transform usage - prompt_tokens = response.usage.prompt_tokens - completion_tokens = response.usage.completion_tokens - else: - # calculate num tokens - prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) - completion_tokens = self._num_tokens_from_messages(credentials, [assistant_prompt_message]) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - # transform response - result = LLMResult( - model=response.model or model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage, - system_fingerprint=response.system_fingerprint, - ) - - return result - - def _handle_chat_generate_stream_response( - self, - model: str, - credentials: dict, - response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - ): - index = 0 - full_assistant_content = "" - real_model = model - system_fingerprint = None - completion = "" - tool_calls = [] - for chunk in response: - if len(chunk.choices) == 0: - continue - - delta = chunk.choices[0] - # NOTE: For fix https://github.com/langgenius/dify/issues/5790 - if delta.delta is None: - continue - - # extract tool calls from response - self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls) - - # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter - if delta.finish_reason is None and not delta.delta.content: - continue - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls) - - full_assistant_content += delta.delta.content or "" - - real_model = chunk.model - system_fingerprint = chunk.system_fingerprint - completion += delta.delta.content or "" - - yield LLMResultChunk( - model=real_model, - prompt_messages=prompt_messages, - system_fingerprint=system_fingerprint, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - ), - ) - - index += 1 - - # calculate num tokens - prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) - - full_assistant_prompt_message = AssistantPromptMessage(content=completion) - completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message]) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - yield LLMResultChunk( - model=real_model, - prompt_messages=prompt_messages, - system_fingerprint=system_fingerprint, - delta=LLMResultChunkDelta( - index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage - ), - ) - - @staticmethod - def _update_tool_calls( - tool_calls: list[AssistantPromptMessage.ToolCall], - tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]], - ) -> None: - if 
tool_calls_response: - for response_tool_call in tool_calls_response: - if isinstance(response_tool_call, ChatCompletionMessageToolCall): - function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, arguments=response_tool_call.function.arguments - ) - - tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, type=response_tool_call.type, function=function - ) - tool_calls.append(tool_call) - elif isinstance(response_tool_call, ChoiceDeltaToolCall): - index = response_tool_call.index - if index < len(tool_calls): - tool_calls[index].id = response_tool_call.id or tool_calls[index].id - tool_calls[index].type = response_tool_call.type or tool_calls[index].type - if response_tool_call.function: - tool_calls[index].function.name = ( - response_tool_call.function.name or tool_calls[index].function.name - ) - tool_calls[index].function.arguments += response_tool_call.function.arguments or "" - else: - assert response_tool_call.id is not None - assert response_tool_call.type is not None - assert response_tool_call.function is not None - assert response_tool_call.function.name is not None - assert response_tool_call.function.arguments is not None - - function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, arguments=response_tool_call.function.arguments - ) - tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, type=response_tool_call.type, function=function - ) - tool_calls.append(tool_call) - - @staticmethod - def _convert_prompt_message_to_dict(message: PromptMessage): - if isinstance(message, UserPromptMessage): - message = cast(UserPromptMessage, message) - if isinstance(message.content, str): - message_dict = {"role": "user", "content": message.content} - else: - sub_messages = [] - assert message.content is not None - for message_content in message.content: - if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = {"type": "text", "text": message_content.data} - sub_messages.append(sub_message_dict) - elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, message_content) - sub_message_dict = { - "type": "image_url", - "image_url": {"url": message_content.data, "detail": message_content.detail.value}, - } - sub_messages.append(sub_message_dict) - message_dict = {"role": "user", "content": sub_messages} - elif isinstance(message, AssistantPromptMessage): - # message = cast(AssistantPromptMessage, message) - message_dict = {"role": "assistant", "content": message.content} - if message.tool_calls: - message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls] - elif isinstance(message, SystemPromptMessage): - message = cast(SystemPromptMessage, message) - message_dict = {"role": "system", "content": message.content} - elif isinstance(message, ToolPromptMessage): - message = cast(ToolPromptMessage, message) - message_dict = { - "role": "tool", - "name": message.name, - "content": message.content, - "tool_call_id": message.tool_call_id, - } - else: - raise ValueError(f"Got unknown type {message}") - - if message.name: - message_dict["name"] = message.name - - return message_dict - - def _num_tokens_from_string( - self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None - ) -> int: - try: - encoding = tiktoken.encoding_for_model(credentials["base_model_name"]) - 
except KeyError: - encoding = tiktoken.get_encoding("cl100k_base") - - num_tokens = len(encoding.encode(text)) - - if tools: - num_tokens += self._num_tokens_for_tools(encoding, tools) - - return num_tokens - - def _num_tokens_from_messages( - self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None - ) -> int: - """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. - - Official documentation: https://github.com/openai/openai-cookbook/blob/ - main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - model = credentials["base_model_name"] - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") - model = "cl100k_base" - encoding = tiktoken.get_encoding(model) - - if model.startswith("gpt-35-turbo-0301"): - # every message follows {role/name}\n{content}\n - tokens_per_message = 4 - # if there's a name, the role is omitted - tokens_per_name = -1 - elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or model.startswith("o1"): - tokens_per_message = 3 - tokens_per_name = 1 - else: - raise NotImplementedError( - f"get_num_tokens_from_messages() is not presently implemented " - f"for model {model}." - "See https://github.com/openai/openai-python/blob/main/chatml.md for " - "information on how messages are converted to tokens." - ) - num_tokens = 0 - messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] - for message in messages_dict: - num_tokens += tokens_per_message - for key, value in message.items(): - # Cast str(value) in case the message value is not a string - # This occurs with function messages - # TODO: The current token calculation method for the image type is not implemented, - # which need to download the image and then get the resolution for calculation, - # and will increase the request delay - if isinstance(value, list): - text = "" - for item in value: - if isinstance(item, dict) and item["type"] == "text": - text += item["text"] - - value = text - - if key == "tool_calls": - for tool_call in value: - assert isinstance(tool_call, dict) - for t_key, t_value in tool_call.items(): - num_tokens += len(encoding.encode(t_key)) - if t_key == "function": - for f_key, f_value in t_value.items(): - num_tokens += len(encoding.encode(f_key)) - num_tokens += len(encoding.encode(f_value)) - else: - num_tokens += len(encoding.encode(t_key)) - num_tokens += len(encoding.encode(t_value)) - else: - num_tokens += len(encoding.encode(str(value))) - - if key == "name": - num_tokens += tokens_per_name - - # every reply is primed with assistant - num_tokens += 3 - - if tools: - num_tokens += self._num_tokens_for_tools(encoding, tools) - - return num_tokens - - @staticmethod - def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int: - num_tokens = 0 - for tool in tools: - num_tokens += len(encoding.encode("type")) - num_tokens += len(encoding.encode("function")) - - # calculate num tokens for function object - num_tokens += len(encoding.encode("name")) - num_tokens += len(encoding.encode(tool.name)) - num_tokens += len(encoding.encode("description")) - num_tokens += len(encoding.encode(tool.description)) - parameters = tool.parameters - num_tokens += len(encoding.encode("parameters")) - if "title" in parameters: - num_tokens += len(encoding.encode("title")) - num_tokens += len(encoding.encode(parameters["title"])) - num_tokens += len(encoding.encode("type")) - 
num_tokens += len(encoding.encode(parameters["type"])) - if "properties" in parameters: - num_tokens += len(encoding.encode("properties")) - for key, value in parameters["properties"].items(): - num_tokens += len(encoding.encode(key)) - for field_key, field_value in value.items(): - num_tokens += len(encoding.encode(field_key)) - if field_key == "enum": - for enum_field in field_value: - num_tokens += 3 - num_tokens += len(encoding.encode(enum_field)) - else: - num_tokens += len(encoding.encode(field_key)) - num_tokens += len(encoding.encode(str(field_value))) - if "required" in parameters: - num_tokens += len(encoding.encode("required")) - for required_field in parameters["required"]: - num_tokens += 3 - num_tokens += len(encoding.encode(required_field)) - - return num_tokens - - @staticmethod - def _get_ai_model_entity(base_model_name: str, model: str): - for ai_model_entity in LLM_BASE_MODELS: - if ai_model_entity.base_model_name == base_model_name: - ai_model_entity_copy = copy.deepcopy(ai_model_entity) - ai_model_entity_copy.entity.model = model - ai_model_entity_copy.entity.label.en_US = model - ai_model_entity_copy.entity.label.zh_Hans = model - return ai_model_entity_copy - - def _get_base_model_name(self, credentials: dict) -> str: - base_model_name = credentials.get("base_model_name") - if not base_model_name: - raise ValueError("Base Model Name is required") - return base_model_name diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py deleted file mode 100644 index b1b07a611b..0000000000 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ /dev/null @@ -1,450 +0,0 @@ -import base64 -import io -import json -import logging -from collections.abc import Generator -from typing import Optional, Union, cast - -import google.ai.generativelanguage as glm -import google.generativeai as genai -import requests -from google.api_core import exceptions -from google.generativeai.client import _ClientManager -from google.generativeai.types import ContentType, GenerateContentResponse -from google.generativeai.types.content_types import to_part -from PIL import Image - -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, - PromptMessageTool, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) -from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel - -logger = logging.getLogger(__name__) - -GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. -The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure -if you are not sure about the structure. 
- -<instructions> -{{instructions}} -</instructions> -""" # noqa: E501 - - -class GoogleLargeLanguageModel(LargeLanguageModel): - def _invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, - ) -> Union[LLMResult, Generator]: - """ - Invoke large language model - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - :return: full response or stream response chunk generator result - """ - # invoke model - return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - - def get_num_tokens( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - ) -> int: - """ - Get number of tokens for given prompt messages - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param tools: tools for tool calling - :return:md = genai.GenerativeModel(model) - """ - prompt = self._convert_messages_to_prompt(prompt_messages) - - return self._get_num_tokens_by_gpt2(prompt) - - def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: - """ - Format a list of messages into a full prompt for the Google model - - :param messages: List of PromptMessage to combine. - :return: Combined string with necessary human_prompt and ai_prompt tags. - """ - messages = messages.copy() # don't mutate the original list - - text = "".join(self._convert_one_message_to_text(message) for message in messages) - - return text.rstrip() - - def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool: - """ - Convert tool messages to glm tools - - :param tools: tool messages - :return: glm tools - """ - function_declarations = [] - for tool in tools: - properties = {} - for key, value in tool.parameters.get("properties", {}).items(): - properties[key] = { - "type_": glm.Type.STRING, - "description": value.get("description", ""), - "enum": value.get("enum", []), - } - - if properties: - parameters = glm.Schema( - type=glm.Type.OBJECT, - properties=properties, - required=tool.parameters.get("required", []), - ) - else: - parameters = None - - function_declaration = glm.FunctionDeclaration( - name=tool.name, - parameters=parameters, - description=tool.description, - ) - function_declarations.append(function_declaration) - - return glm.Tool(function_declarations=function_declarations) - - def validate_credentials(self, model: str, credentials: dict) -> None: - """ - Validate model credentials - - :param model: model name - :param credentials: model credentials - :return: - """ - - try: - ping_message = SystemPromptMessage(content="ping") - self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) - - except Exception as ex: - raise CredentialsValidateFailedError(str(ex)) - - def _generate( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, - ) -> Union[LLMResult, Generator]: - """ - Invoke large language model - 
- :param model: model name - :param credentials: credentials kwargs - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param stop: stop words - :param stream: is stream response - :param user: unique user id - :return: full response or stream response chunk generator result - """ - config_kwargs = model_parameters.copy() - config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None) - - if stop: - config_kwargs["stop_sequences"] = stop - - google_model = genai.GenerativeModel(model_name=model) - - history = [] - - # hack for gemini-pro-vision, which currently does not support multi-turn chat - if model == "gemini-pro-vision": - last_msg = prompt_messages[-1] - content = self._format_message_to_glm_content(last_msg) - history.append(content) - else: - for msg in prompt_messages: # makes message roles strictly alternating - content = self._format_message_to_glm_content(msg) - if history and history[-1]["role"] == content["role"]: - history[-1]["parts"].extend(content["parts"]) - else: - history.append(content) - - # Create a new ClientManager with tenant's API key - new_client_manager = _ClientManager() - new_client_manager.configure(api_key=credentials["google_api_key"]) - new_custom_client = new_client_manager.make_client("generative") - - google_model._client = new_custom_client - - response = google_model.generate_content( - contents=history, - generation_config=genai.types.GenerationConfig(**config_kwargs), - stream=stream, - tools=self._convert_tools_to_glm_tool(tools) if tools else None, - request_options={"timeout": 600}, - ) - - if stream: - return self._handle_generate_stream_response(model, credentials, response, prompt_messages) - - return self._handle_generate_response(model, credentials, response, prompt_messages) - - def _handle_generate_response( - self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage] - ) -> LLMResult: - """ - Handle llm response - - :param model: model name - :param credentials: credentials - :param response: response - :param prompt_messages: prompt messages - :return: llm response - """ - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage(content=response.text) - - # calculate num tokens - prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - # transform response - result = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage, - ) - - return result - - def _handle_generate_stream_response( - self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage] - ) -> Generator: - """ - Handle llm stream response - - :param model: model name - :param credentials: credentials - :param response: response - :param prompt_messages: prompt messages - :return: llm response chunk generator result - """ - index = -1 - for chunk in response: - for part in chunk.parts: - assistant_prompt_message = AssistantPromptMessage(content="") - - if part.text: - assistant_prompt_message.content += part.text - - if part.function_call: - assistant_prompt_message.tool_calls = [ - AssistantPromptMessage.ToolCall( - id=part.function_call.name, - type="function", - 
function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=part.function_call.name, - arguments=json.dumps(dict(part.function_call.args.items())), - ), - ) - ] - - index += 1 - - if not response._done: - # transform assistant message to prompt message - yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), - ) - else: - # calculate num tokens - prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - finish_reason=str(chunk.candidates[0].finish_reason), - usage=usage, - ), - ) - - def _convert_one_message_to_text(self, message: PromptMessage) -> str: - """ - Convert a single message to a string. - - :param message: PromptMessage to convert. - :return: String representation of the message. - """ - human_prompt = "\n\nuser:" - ai_prompt = "\n\nmodel:" - - content = message.content - if isinstance(content, list): - content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) - - if isinstance(message, UserPromptMessage): - message_text = f"{human_prompt} {content}" - elif isinstance(message, AssistantPromptMessage): - message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage | ToolPromptMessage): - message_text = f"{human_prompt} {content}" - else: - raise ValueError(f"Got unknown type {message}") - - return message_text - - def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType: - """ - Format a single message into glm.Content for Google API - - :param message: one PromptMessage - :return: glm Content representation of message - """ - if isinstance(message, UserPromptMessage): - glm_content = {"role": "user", "parts": []} - if isinstance(message.content, str): - glm_content["parts"].append(to_part(message.content)) - else: - for c in message.content: - if c.type == PromptMessageContentType.TEXT: - glm_content["parts"].append(to_part(c.data)) - elif c.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, c) - if message_content.data.startswith("data:"): - metadata, base64_data = c.data.split(",", 1) - mime_type = metadata.split(";", 1)[0].split(":")[1] - else: - # fetch image data from url - try: - image_content = requests.get(message_content.data).content - with Image.open(io.BytesIO(image_content)) as img: - mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode("utf-8") - except Exception as ex: - raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") - blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}} - glm_content["parts"].append(blob) - - return glm_content - elif isinstance(message, AssistantPromptMessage): - glm_content = {"role": "model", "parts": []} - if message.content: - glm_content["parts"].append(to_part(message.content)) - if message.tool_calls: - glm_content["parts"].append( - to_part( - glm.FunctionCall( - name=message.tool_calls[0].function.name, - args=json.loads(message.tool_calls[0].function.arguments), - ) - ) - ) - return glm_content - elif isinstance(message, SystemPromptMessage): - return 
{"role": "user", "parts": [to_part(message.content)]} - elif isinstance(message, ToolPromptMessage): - return { - "role": "function", - "parts": [ - glm.Part( - function_response=glm.FunctionResponse( - name=message.name, response={"response": message.content} - ) - ) - ], - } - else: - raise ValueError(f"Got unknown type {message}") - - @property - def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - """ - Map model invoke error to unified error - The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller - The value is the md = genai.GenerativeModel(model) error type thrown by the model, - which needs to be converted into a unified error type for the caller. - - :return: Invoke emd = genai.GenerativeModel(model) error mapping - """ - return { - InvokeConnectionError: [exceptions.RetryError], - InvokeServerUnavailableError: [ - exceptions.ServiceUnavailable, - exceptions.InternalServerError, - exceptions.BadGateway, - exceptions.GatewayTimeout, - exceptions.DeadlineExceeded, - ], - InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests], - InvokeAuthorizationError: [ - exceptions.Unauthenticated, - exceptions.PermissionDenied, - exceptions.Unauthenticated, - exceptions.Forbidden, - ], - InvokeBadRequestError: [ - exceptions.BadRequest, - exceptions.InvalidArgument, - exceptions.FailedPrecondition, - exceptions.OutOfRange, - exceptions.NotFound, - exceptions.MethodNotAllowed, - exceptions.Conflict, - exceptions.AlreadyExists, - exceptions.Aborted, - exceptions.LengthRequired, - exceptions.PreconditionFailed, - exceptions.RequestRangeNotSatisfiable, - exceptions.Cancelled, - ], - } diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py deleted file mode 100644 index 5c955c86d3..0000000000 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ /dev/null @@ -1,330 +0,0 @@ -import json -from collections.abc import Generator -from typing import Optional, Union, cast - -import requests - -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - PromptMessageContent, - PromptMessageContentType, - PromptMessageTool, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) -from core.model_runtime.entities.model_entities import ( - AIModelEntity, - FetchFrom, - ModelFeature, - ModelPropertyKey, - ModelType, - ParameterRule, - ParameterType, -) -from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel - - -class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, - ) -> Union[LLMResult, Generator]: - self._add_custom_parameters(credentials) - self._add_function_call(model, credentials) - user = user[:32] if user else None - # {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}} - if "response_format" in model_parameters: - model_parameters["response_format"] = {"type": model_parameters.get("response_format")} - 
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - - def validate_credentials(self, model: str, credentials: dict) -> None: - self._add_custom_parameters(credentials) - super().validate_credentials(model, credentials) - - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - return AIModelEntity( - model=model, - label=I18nObject(en_US=model, zh_Hans=model), - model_type=ModelType.LLM, - features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] - if credentials.get("function_calling_type") == "tool_call" - else [], - fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)), - ModelPropertyKey.MODE: LLMMode.CHAT.value, - }, - parameter_rules=[ - ParameterRule( - name="temperature", - use_template="temperature", - label=I18nObject(en_US="Temperature", zh_Hans="温度"), - type=ParameterType.FLOAT, - ), - ParameterRule( - name="max_tokens", - use_template="max_tokens", - default=512, - min=1, - max=int(credentials.get("max_tokens", 4096)), - label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), - type=ParameterType.INT, - ), - ParameterRule( - name="top_p", - use_template="top_p", - label=I18nObject(en_US="Top P", zh_Hans="Top P"), - type=ParameterType.FLOAT, - ), - ], - ) - - def _add_custom_parameters(self, credentials: dict) -> None: - credentials["mode"] = "chat" - if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": - credentials["endpoint_url"] = "https://api.moonshot.cn/v1" - - def _add_function_call(self, model: str, credentials: dict) -> None: - model_schema = self.get_model_schema(model, credentials) - if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection( - model_schema.features or [] - ): - credentials["function_calling_type"] = "tool_call" - - def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict: - """ - Convert PromptMessage to dict for OpenAI API format - """ - if isinstance(message, UserPromptMessage): - message = cast(UserPromptMessage, message) - if isinstance(message.content, str): - message_dict = {"role": "user", "content": message.content} - else: - sub_messages = [] - for message_content in message.content: - if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(PromptMessageContent, message_content) - sub_message_dict = {"type": "text", "text": message_content.data} - sub_messages.append(sub_message_dict) - elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, message_content) - sub_message_dict = { - "type": "image_url", - "image_url": {"url": message_content.data, "detail": message_content.detail.value}, - } - sub_messages.append(sub_message_dict) - message_dict = {"role": "user", "content": sub_messages} - elif isinstance(message, AssistantPromptMessage): - message = cast(AssistantPromptMessage, message) - message_dict = {"role": "assistant", "content": message.content} - if message.tool_calls: - message_dict["tool_calls"] = [] - for function_call in message.tool_calls: - message_dict["tool_calls"].append( - { - "id": function_call.id, - "type": function_call.type, - "function": { - "name": function_call.function.name, - "arguments": function_call.function.arguments, - }, - } - ) - elif isinstance(message, ToolPromptMessage): - message = 
cast(ToolPromptMessage, message) - message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} - elif isinstance(message, SystemPromptMessage): - message = cast(SystemPromptMessage, message) - message_dict = {"role": "system", "content": message.content} - else: - raise ValueError(f"Got unknown type {message}") - - if message.name: - message_dict["name"] = message.name - - return message_dict - - def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]: - """ - Extract tool calls from response - - :param response_tool_calls: response tool calls - :return: list of tool calls - """ - tool_calls = [] - if response_tool_calls: - for response_tool_call in response_tool_calls: - function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call["function"]["name"] - if response_tool_call.get("function", {}).get("name") - else "", - arguments=response_tool_call["function"]["arguments"] - if response_tool_call.get("function", {}).get("arguments") - else "", - ) - - tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call["id"] if response_tool_call.get("id") else "", - type=response_tool_call["type"] if response_tool_call.get("type") else "", - function=function, - ) - tool_calls.append(tool_call) - - return tool_calls - - def _handle_generate_stream_response( - self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] - ) -> Generator: - """ - Handle llm stream response - - :param model: model name - :param credentials: model credentials - :param response: streamed response - :param prompt_messages: prompt messages - :return: llm response chunk generator - """ - full_assistant_content = "" - chunk_index = 0 - - def create_final_llm_result_chunk( - index: int, message: AssistantPromptMessage, finish_reason: str - ) -> LLMResultChunk: - # calculate num tokens - prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) - completion_tokens = self._num_tokens_from_string(model, full_assistant_content) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - return LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), - ) - - tools_calls: list[AssistantPromptMessage.ToolCall] = [] - finish_reason = "Unknown" - - def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]): - def get_tool_call(tool_name: str): - if not tool_name: - return tools_calls[-1] - - tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None) - if tool_call is None: - tool_call = AssistantPromptMessage.ToolCall( - id="", - type="", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""), - ) - tools_calls.append(tool_call) - - return tool_call - - for new_tool_call in new_tool_calls: - # get tool call - tool_call = get_tool_call(new_tool_call.function.name) - # update tool call - if new_tool_call.id: - tool_call.id = new_tool_call.id - if new_tool_call.type: - tool_call.type = new_tool_call.type - if new_tool_call.function.name: - tool_call.function.name = new_tool_call.function.name - if new_tool_call.function.arguments: - tool_call.function.arguments += new_tool_call.function.arguments - - for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"): - if chunk: - 
# ignore sse comments - if chunk.startswith(":"): - continue - decoded_chunk = chunk.strip().lstrip("data: ").lstrip() - chunk_json = None - try: - chunk_json = json.loads(decoded_chunk) - # stream ended - except json.JSONDecodeError as e: - yield create_final_llm_result_chunk( - index=chunk_index + 1, - message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered.", - ) - break - if not chunk_json or len(chunk_json["choices"]) == 0: - continue - - choice = chunk_json["choices"][0] - finish_reason = chunk_json["choices"][0].get("finish_reason") - chunk_index += 1 - - if "delta" in choice: - delta = choice["delta"] - delta_content = delta.get("content") - - assistant_message_tool_calls = delta.get("tool_calls", None) - # assistant_message_function_call = delta.delta.function_call - - # extract tool calls from response - if assistant_message_tool_calls: - tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) - increase_tool_call(tool_calls) - - if delta_content is None or delta_content == "": - continue - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else [] - ) - - full_assistant_content += delta_content - elif "text" in choice: - choice_text = choice.get("text", "") - if choice_text == "": - continue - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage(content=choice_text) - full_assistant_content += choice_text - else: - continue - - # check payload indicator for completion - yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=chunk_index, - message=assistant_prompt_message, - ), - ) - - chunk_index += 1 - - if tools_calls: - yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=chunk_index, - message=AssistantPromptMessage(tool_calls=tools_calls, content=""), - ), - ) - - yield create_final_llm_result_chunk( - index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason - ) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py deleted file mode 100644 index e1342fe985..0000000000 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ /dev/null @@ -1,847 +0,0 @@ -import json -import logging -from collections.abc import Generator -from decimal import Decimal -from typing import Optional, Union, cast -from urllib.parse import urljoin - -import requests - -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - PromptMessageContent, - PromptMessageContentType, - PromptMessageFunction, - PromptMessageTool, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) -from core.model_runtime.entities.model_entities import ( - AIModelEntity, - DefaultParameterName, - FetchFrom, - ModelFeature, - ModelPropertyKey, - ModelType, - ParameterRule, - ParameterType, - PriceConfig, -) -from core.model_runtime.errors.invoke import InvokeError -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from 
core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -from core.model_runtime.utils import helper - -logger = logging.getLogger(__name__) - - -class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): - """ - Model class for OpenAI large language model. - """ - - def _invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, - ) -> Union[LLMResult, Generator]: - """ - Invoke large language model - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - :return: full response or stream response chunk generator result - """ - - # text completion model - return self._generate( - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - - def get_num_tokens( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - ) -> int: - """ - Get number of tokens for given prompt messages - - :param model: - :param credentials: - :param prompt_messages: - :param tools: tools for tool calling - :return: - """ - return self._num_tokens_from_messages(model, prompt_messages, tools, credentials) - - def validate_credentials(self, model: str, credentials: dict) -> None: - """ - Validate model credentials using requests to ensure compatibility with all providers following - OpenAI's API standard. 
- - :param model: model name - :param credentials: model credentials - :return: - """ - try: - headers = {"Content-Type": "application/json"} - - api_key = credentials.get("api_key") - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - - endpoint_url = credentials["endpoint_url"] - if not endpoint_url.endswith("/"): - endpoint_url += "/" - - # prepare the payload for a simple ping to the model - data = {"model": model, "max_tokens": 5} - - completion_type = LLMMode.value_of(credentials["mode"]) - - if completion_type is LLMMode.CHAT: - data["messages"] = [ - {"role": "user", "content": "ping"}, - ] - endpoint_url = urljoin(endpoint_url, "chat/completions") - elif completion_type is LLMMode.COMPLETION: - data["prompt"] = "ping" - endpoint_url = urljoin(endpoint_url, "completions") - else: - raise ValueError("Unsupported completion type for model configuration.") - - # send a post request to validate the credentials - response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300)) - - if response.status_code != 200: - raise CredentialsValidateFailedError( - f"Credentials validation failed with status code {response.status_code}" - ) - - try: - json_result = response.json() - except json.JSONDecodeError as e: - raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - - if completion_type is LLMMode.CHAT and json_result.get("object", "") == "": - json_result["object"] = "chat.completion" - elif completion_type is LLMMode.COMPLETION and json_result.get("object", "") == "": - json_result["object"] = "text_completion" - - if completion_type is LLMMode.CHAT and ( - "object" not in json_result or json_result["object"] != "chat.completion" - ): - raise CredentialsValidateFailedError( - "Credentials validation failed: invalid response object, must be 'chat.completion'" - ) - elif completion_type is LLMMode.COMPLETION and ( - "object" not in json_result or json_result["object"] != "text_completion" - ): - raise CredentialsValidateFailedError( - "Credentials validation failed: invalid response object, must be 'text_completion'" - ) - except CredentialsValidateFailedError: - raise - except Exception as ex: - raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") - - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: - """ - generate custom model entities from credentials - """ - features = [] - - function_calling_type = credentials.get("function_calling_type", "no_call") - if function_calling_type == "function_call": - features.append(ModelFeature.TOOL_CALL) - elif function_calling_type == "tool_call": - features.append(ModelFeature.MULTI_TOOL_CALL) - - stream_function_calling = credentials.get("stream_function_calling", "supported") - if stream_function_calling == "supported": - features.append(ModelFeature.STREAM_TOOL_CALL) - - vision_support = credentials.get("vision_support", "not_support") - if vision_support == "support": - features.append(ModelFeature.VISION) - - entity = AIModelEntity( - model=model, - label=I18nObject(en_US=model), - model_type=ModelType.LLM, - fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - features=features, - model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096")), - ModelPropertyKey.MODE: credentials.get("mode"), - }, - parameter_rules=[ - ParameterRule( - name=DefaultParameterName.TEMPERATURE.value, - label=I18nObject(en_US="Temperature", zh_Hans="温度"), - help=I18nObject( - en_US="Kernel 
sampling threshold. Used to determine the randomness of the results."
-                        " The higher the value, the stronger the randomness"
-                        " and the more likely the same question is to get different answers.",
-                        zh_Hans="核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。",
-                    ),
-                    type=ParameterType.FLOAT,
-                    default=float(credentials.get("temperature", 0.7)),
-                    min=0,
-                    max=2,
-                    precision=2,
-                ),
-                ParameterRule(
-                    name=DefaultParameterName.TOP_P.value,
-                    label=I18nObject(en_US="Top P", zh_Hans="Top P"),
-                    help=I18nObject(
-                        en_US="The probability threshold of the nucleus sampling method during the generation process."
-                        " The larger the value, the higher the randomness of the generation;"
-                        " the smaller the value, the higher the certainty of the generation.",
-                        zh_Hans="生成过程中核采样方法概率阈值。取值越大,生成的随机性越高;取值越小,生成的确定性越高。",
-                    ),
-                    type=ParameterType.FLOAT,
-                    default=float(credentials.get("top_p", 1)),
-                    min=0,
-                    max=1,
-                    precision=2,
-                ),
-                ParameterRule(
-                    name=DefaultParameterName.FREQUENCY_PENALTY.value,
-                    label=I18nObject(en_US="Frequency Penalty", zh_Hans="频率惩罚"),
-                    help=I18nObject(
-                        en_US="Controls the repetition rate of words used by the model."
-                        " Increasing this reduces the repetition of the same words in the model's output.",
-                        zh_Hans="用于控制模型已使用字词的重复率。 提高此项可以降低模型在输出中重复相同字词的重复度。",
-                    ),
-                    type=ParameterType.FLOAT,
-                    default=float(credentials.get("frequency_penalty", 0)),
-                    min=-2,
-                    max=2,
-                ),
-                ParameterRule(
-                    name=DefaultParameterName.PRESENCE_PENALTY.value,
-                    label=I18nObject(en_US="Presence Penalty", zh_Hans="存在惩罚"),
-                    help=I18nObject(
-                        en_US="Controls the repetition rate of the model's generation."
-                        " Increasing this reduces repetition in the model's output.",
-                        zh_Hans="用于控制模型生成时的重复度。提高此项可以降低模型生成的重复度。",
-                    ),
-                    type=ParameterType.FLOAT,
-                    default=float(credentials.get("presence_penalty", 0)),
-                    min=-2,
-                    max=2,
-                ),
-                ParameterRule(
-                    name=DefaultParameterName.MAX_TOKENS.value,
-                    label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
-                    help=I18nObject(
-                        en_US="Maximum number of tokens for the model response.", zh_Hans="模型回答的tokens的最大长度。"
-                    ),
-                    type=ParameterType.INT,
-                    default=512,
-                    min=1,
-                    max=int(credentials.get("max_tokens_to_sample", 4096)),
-                ),
-            ],
-            pricing=PriceConfig(
-                input=Decimal(credentials.get("input_price", 0)),
-                output=Decimal(credentials.get("output_price", 0)),
-                unit=Decimal(credentials.get("unit", 0)),
-                currency=credentials.get("currency", "USD"),
-            ),
-        )
-
-        if credentials["mode"] == "chat":
-            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
-        elif credentials["mode"] == "completion":
-            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
-        else:
-            raise ValueError(f"Unknown completion mode {credentials['mode']}")
-
-        return entity
-
-    # validate_credentials method has been rewritten to use the requests library for compatibility with all providers
-    # following OpenAI's API standard.
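The deleted `validate_credentials` above reduces to a single "ping" request against the provider's OpenAI-compatible route. A minimal standalone sketch of that pattern, assuming a standard `chat/completions` route; the endpoint, key, and model name are placeholders:

```python
# Hedged sketch of the credential-validation ping described above.
# Assumes an OpenAI-compatible server; endpoint_url/api_key/model are placeholders.
import requests


def ping_chat_endpoint(endpoint_url: str, api_key: str, model: str) -> bool:
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    payload = {
        "model": model,
        "max_tokens": 5,
        "messages": [{"role": "user", "content": "ping"}],
    }
    # (10, 300) mirrors the connect/read timeouts used in the deleted code above
    resp = requests.post(
        endpoint_url.rstrip("/") + "/chat/completions",
        headers=headers,
        json=payload,
        timeout=(10, 300),
    )
    if resp.status_code != 200:
        return False
    # Servers that omit "object" are treated as chat-capable, as in the code above
    return resp.json().get("object", "chat.completion") == "chat.completion"
```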
- def _generate( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, - ) -> Union[LLMResult, Generator]: - """ - Invoke llm completion model - - :param model: model name - :param credentials: credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param stop: stop words - :param stream: is stream response - :param user: unique user id - :return: full response or stream response chunk generator result - """ - headers = { - "Content-Type": "application/json", - "Accept-Charset": "utf-8", - } - extra_headers = credentials.get("extra_headers") - if extra_headers is not None: - headers = { - **headers, - **extra_headers, - } - - api_key = credentials.get("api_key") - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - - endpoint_url = credentials["endpoint_url"] - if not endpoint_url.endswith("/"): - endpoint_url += "/" - - data = {"model": model, "stream": stream, **model_parameters} - - completion_type = LLMMode.value_of(credentials["mode"]) - - if completion_type is LLMMode.CHAT: - endpoint_url = urljoin(endpoint_url, "chat/completions") - data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] - elif completion_type is LLMMode.COMPLETION: - endpoint_url = urljoin(endpoint_url, "completions") - data["prompt"] = prompt_messages[0].content - else: - raise ValueError("Unsupported completion type for model configuration.") - - # annotate tools with names, descriptions, etc. - function_calling_type = credentials.get("function_calling_type", "no_call") - formatted_tools = [] - if tools: - if function_calling_type == "function_call": - data["functions"] = [ - {"name": tool.name, "description": tool.description, "parameters": tool.parameters} - for tool in tools - ] - elif function_calling_type == "tool_call": - data["tool_choice"] = "auto" - - for tool in tools: - formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool))) - - data["tools"] = formatted_tools - - if stop: - data["stop"] = stop - - if user: - data["user"] = user - - response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) - - if response.encoding is None or response.encoding == "ISO-8859-1": - response.encoding = "utf-8" - - if response.status_code != 200: - raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") - - if stream: - return self._handle_generate_stream_response(model, credentials, response, prompt_messages) - - return self._handle_generate_response(model, credentials, response, prompt_messages) - - def _handle_generate_stream_response( - self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] - ) -> Generator: - """ - Handle llm stream response - - :param model: model name - :param credentials: model credentials - :param response: streamed response - :param prompt_messages: prompt messages - :return: llm response chunk generator - """ - full_assistant_content = "" - chunk_index = 0 - - def create_final_llm_result_chunk( - id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict - ) -> LLMResultChunk: - # calculate num tokens - prompt_tokens = usage and usage.get("prompt_tokens") - if prompt_tokens is None: - prompt_tokens = 
self._num_tokens_from_string(model, prompt_messages[0].content) - completion_tokens = usage and usage.get("completion_tokens") - if completion_tokens is None: - completion_tokens = self._num_tokens_from_string(model, full_assistant_content) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - return LLMResultChunk( - id=id, - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), - ) - - # delimiter for stream response, need unicode_escape - import codecs - - delimiter = credentials.get("stream_mode_delimiter", "\n\n") - delimiter = codecs.decode(delimiter, "unicode_escape") - - tools_calls: list[AssistantPromptMessage.ToolCall] = [] - - def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]): - def get_tool_call(tool_call_id: str): - if not tool_call_id: - return tools_calls[-1] - - tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None) - if tool_call is None: - tool_call = AssistantPromptMessage.ToolCall( - id=tool_call_id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), - ) - tools_calls.append(tool_call) - - return tool_call - - for new_tool_call in new_tool_calls: - # get tool call - tool_call = get_tool_call(new_tool_call.function.name) - # update tool call - if new_tool_call.id: - tool_call.id = new_tool_call.id - if new_tool_call.type: - tool_call.type = new_tool_call.type - if new_tool_call.function.name: - tool_call.function.name = new_tool_call.function.name - if new_tool_call.function.arguments: - tool_call.function.arguments += new_tool_call.function.arguments - - finish_reason = None # The default value of finish_reason is None - message_id, usage = None, None - for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): - chunk = chunk.strip() - if chunk: - # ignore sse comments - if chunk.startswith(":"): - continue - decoded_chunk = chunk.strip().lstrip("data: ").lstrip() - if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]" - continue - - try: - chunk_json: dict = json.loads(decoded_chunk) - # stream ended - except json.JSONDecodeError as e: - yield create_final_llm_result_chunk( - id=message_id, - index=chunk_index + 1, - message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered.", - usage=usage, - ) - break - if chunk_json: - if u := chunk_json.get("usage"): - usage = u - if not chunk_json or len(chunk_json["choices"]) == 0: - continue - - choice = chunk_json["choices"][0] - finish_reason = chunk_json["choices"][0].get("finish_reason") - message_id = chunk_json.get("id") - chunk_index += 1 - - if "delta" in choice: - delta = choice["delta"] - delta_content = delta.get("content") - - assistant_message_tool_calls = None - - if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call": - assistant_message_tool_calls = delta.get("tool_calls", None) - elif ( - "function_call" in delta - and credentials.get("function_calling_type", "no_call") == "function_call" - ): - assistant_message_tool_calls = [ - {"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})} - ] - - # assistant_message_function_call = delta.delta.function_call - - # extract tool calls from response - if assistant_message_tool_calls: - tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) - 
increase_tool_call(tool_calls) - - if delta_content is None or delta_content == "": - continue - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta_content, - ) - - # reset tool calls - tool_calls = [] - full_assistant_content += delta_content - elif "text" in choice: - choice_text = choice.get("text", "") - if choice_text == "": - continue - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage(content=choice_text) - full_assistant_content += choice_text - else: - continue - - yield LLMResultChunk( - id=message_id, - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=chunk_index, - message=assistant_prompt_message, - ), - ) - - chunk_index += 1 - - if tools_calls: - yield LLMResultChunk( - id=message_id, - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=chunk_index, - message=AssistantPromptMessage(tool_calls=tools_calls, content=""), - ), - ) - - yield create_final_llm_result_chunk( - id=message_id, - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason, - usage=usage, - ) - - def _handle_generate_response( - self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] - ) -> LLMResult: - response_json: dict = response.json() - - completion_type = LLMMode.value_of(credentials["mode"]) - - output = response_json["choices"][0] - message_id = response_json.get("id") - - response_content = "" - tool_calls = None - function_calling_type = credentials.get("function_calling_type", "no_call") - if completion_type is LLMMode.CHAT: - response_content = output.get("message", {})["content"] - if function_calling_type == "tool_call": - tool_calls = output.get("message", {}).get("tool_calls") - elif function_calling_type == "function_call": - tool_calls = output.get("message", {}).get("function_call") - - elif completion_type is LLMMode.COMPLETION: - response_content = output["text"] - - assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[]) - - if tool_calls: - if function_calling_type == "tool_call": - assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls) - elif function_calling_type == "function_call": - assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)] - - usage = response_json.get("usage") - if usage: - # transform usage - prompt_tokens = usage["prompt_tokens"] - completion_tokens = usage["completion_tokens"] - else: - # calculate num tokens - prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) - completion_tokens = self._num_tokens_from_string(model, assistant_message.content) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - # transform response - result = LLMResult( - id=message_id, - model=response_json["model"], - prompt_messages=prompt_messages, - message=assistant_message, - usage=usage, - ) - - return result - - def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict: - """ - Convert PromptMessage to dict for OpenAI API format - """ - if isinstance(message, UserPromptMessage): - message = cast(UserPromptMessage, message) - if isinstance(message.content, str): - message_dict = {"role": "user", "content": message.content} - else: - sub_messages = [] - for message_content in message.content: - if 
message_content.type == PromptMessageContentType.TEXT: - message_content = cast(PromptMessageContent, message_content) - sub_message_dict = {"type": "text", "text": message_content.data} - sub_messages.append(sub_message_dict) - elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, message_content) - sub_message_dict = { - "type": "image_url", - "image_url": {"url": message_content.data, "detail": message_content.detail.value}, - } - sub_messages.append(sub_message_dict) - - message_dict = {"role": "user", "content": sub_messages} - elif isinstance(message, AssistantPromptMessage): - message = cast(AssistantPromptMessage, message) - message_dict = {"role": "assistant", "content": message.content} - if message.tool_calls: - function_calling_type = credentials.get("function_calling_type", "no_call") - if function_calling_type == "tool_call": - message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls] - elif function_calling_type == "function_call": - function_call = message.tool_calls[0] - message_dict["function_call"] = { - "name": function_call.function.name, - "arguments": function_call.function.arguments, - } - elif isinstance(message, SystemPromptMessage): - message = cast(SystemPromptMessage, message) - message_dict = {"role": "system", "content": message.content} - elif isinstance(message, ToolPromptMessage): - message = cast(ToolPromptMessage, message) - function_calling_type = credentials.get("function_calling_type", "no_call") - if function_calling_type == "tool_call": - message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} - elif function_calling_type == "function_call": - message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} - else: - raise ValueError(f"Got unknown type {message}") - - if message.name and message_dict.get("role", "") != "tool": - message_dict["name"] = message.name - - return message_dict - - def _num_tokens_from_string( - self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None - ) -> int: - """ - Approximate num tokens for model with gpt2 tokenizer. - - :param model: model name - :param text: prompt text - :param tools: tools for tool calling - :return: number of tokens - """ - if isinstance(text, str): - full_text = text - else: - full_text = "" - for message_content in text: - if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(PromptMessageContent, message_content) - full_text += message_content.data - - num_tokens = self._get_num_tokens_by_gpt2(full_text) - - if tools: - num_tokens += self._num_tokens_for_tools(tools) - - return num_tokens - - def _num_tokens_from_messages( - self, - model: str, - messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - credentials: Optional[dict] = None, - ) -> int: - """ - Approximate num tokens with GPT2 tokenizer. 
- """ - - tokens_per_message = 3 - tokens_per_name = 1 - - num_tokens = 0 - messages_dict = [self._convert_prompt_message_to_dict(m, credentials) for m in messages] - for message in messages_dict: - num_tokens += tokens_per_message - for key, value in message.items(): - # Cast str(value) in case the message value is not a string - # This occurs with function messages - # TODO: The current token calculation method for the image type is not implemented, - # which need to download the image and then get the resolution for calculation, - # and will increase the request delay - if isinstance(value, list): - text = "" - for item in value: - if isinstance(item, dict) and item["type"] == "text": - text += item["text"] - - value = text - - if key == "tool_calls": - for tool_call in value: - for t_key, t_value in tool_call.items(): - num_tokens += self._get_num_tokens_by_gpt2(t_key) - if t_key == "function": - for f_key, f_value in t_value.items(): - num_tokens += self._get_num_tokens_by_gpt2(f_key) - num_tokens += self._get_num_tokens_by_gpt2(f_value) - else: - num_tokens += self._get_num_tokens_by_gpt2(t_key) - num_tokens += self._get_num_tokens_by_gpt2(t_value) - else: - num_tokens += self._get_num_tokens_by_gpt2(str(value)) - - if key == "name": - num_tokens += tokens_per_name - - # every reply is primed with assistant - num_tokens += 3 - - if tools: - num_tokens += self._num_tokens_for_tools(tools) - - return num_tokens - - def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: - """ - Calculate num tokens for tool calling with tiktoken package. - - :param tools: tools for tool calling - :return: number of tokens - """ - num_tokens = 0 - for tool in tools: - num_tokens += self._get_num_tokens_by_gpt2("type") - num_tokens += self._get_num_tokens_by_gpt2("function") - num_tokens += self._get_num_tokens_by_gpt2("function") - - # calculate num tokens for function object - num_tokens += self._get_num_tokens_by_gpt2("name") - num_tokens += self._get_num_tokens_by_gpt2(tool.name) - num_tokens += self._get_num_tokens_by_gpt2("description") - num_tokens += self._get_num_tokens_by_gpt2(tool.description) - parameters = tool.parameters - num_tokens += self._get_num_tokens_by_gpt2("parameters") - if "title" in parameters: - num_tokens += self._get_num_tokens_by_gpt2("title") - num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title")) - num_tokens += self._get_num_tokens_by_gpt2("type") - num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type")) - if "properties" in parameters: - num_tokens += self._get_num_tokens_by_gpt2("properties") - for key, value in parameters.get("properties").items(): - num_tokens += self._get_num_tokens_by_gpt2(key) - for field_key, field_value in value.items(): - num_tokens += self._get_num_tokens_by_gpt2(field_key) - if field_key == "enum": - for enum_field in field_value: - num_tokens += 3 - num_tokens += self._get_num_tokens_by_gpt2(enum_field) - else: - num_tokens += self._get_num_tokens_by_gpt2(field_key) - num_tokens += self._get_num_tokens_by_gpt2(str(field_value)) - if "required" in parameters: - num_tokens += self._get_num_tokens_by_gpt2("required") - for required_field in parameters["required"]: - num_tokens += 3 - num_tokens += self._get_num_tokens_by_gpt2(required_field) - - return num_tokens - - def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]: - """ - Extract tool calls from response - - :param response_tool_calls: response tool calls - :return: list of tool calls - """ 
- tool_calls = [] - if response_tool_calls: - for response_tool_call in response_tool_calls: - function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.get("function", {}).get("name", ""), - arguments=response_tool_call.get("function", {}).get("arguments", ""), - ) - - tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.get("id", ""), type=response_tool_call.get("type", ""), function=function - ) - tool_calls.append(tool_call) - - return tool_calls - - def _extract_response_function_call(self, response_function_call) -> AssistantPromptMessage.ToolCall: - """ - Extract function call from response - - :param response_function_call: response function call - :return: tool call - """ - tool_call = None - if response_function_call: - function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.get("name", ""), arguments=response_function_call.get("arguments", "") - ) - - tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.get("id", ""), type="function", function=function - ) - - return tool_call diff --git a/api/core/model_runtime/model_providers/vessl_ai/__init__.py b/api/core/model_runtime/model_providers/vessl_ai/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png new file mode 100644 index 0000000000..18ba350fa0 Binary files /dev/null and b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg new file mode 100644 index 0000000000..242f4e82b2 --- /dev/null +++ b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/model_runtime/model_providers/vessl_ai/llm/__init__.py b/api/core/model_runtime/model_providers/vessl_ai/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py new file mode 100644 index 0000000000..034c066ab5 --- /dev/null +++ b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py @@ -0,0 +1,83 @@ +from decimal import Decimal + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceConfig, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel): + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + features = [] + + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties={ + ModelPropertyKey.MODE: credentials.get("mode"), + }, + parameter_rules=[ + ParameterRule( + name=DefaultParameterName.TEMPERATURE.value, + label=I18nObject(en_US="Temperature"), + type=ParameterType.FLOAT, + default=float(credentials.get("temperature", 0.7)), + min=0, + max=2, + precision=2, + ), + ParameterRule( + 
name=DefaultParameterName.TOP_P.value,
+                label=I18nObject(en_US="Top P"),
+                type=ParameterType.FLOAT,
+                default=float(credentials.get("top_p", 1)),
+                min=0,
+                max=1,
+                precision=2,
+            ),
+            ParameterRule(
+                name=DefaultParameterName.TOP_K.value,
+                label=I18nObject(en_US="Top K"),
+                type=ParameterType.INT,
+                default=int(credentials.get("top_k", 50)),
+                min=-2147483647,
+                max=2147483647,
+                precision=0,
+            ),
+            ParameterRule(
+                name=DefaultParameterName.MAX_TOKENS.value,
+                label=I18nObject(en_US="Max Tokens"),
+                type=ParameterType.INT,
+                default=512,
+                min=1,
+                max=int(credentials.get("max_tokens_to_sample", 4096)),
+            ),
+        ],
+        pricing=PriceConfig(
+            input=Decimal(credentials.get("input_price", 0)),
+            output=Decimal(credentials.get("output_price", 0)),
+            unit=Decimal(credentials.get("unit", 0)),
+            currency=credentials.get("currency", "USD"),
+        ),
+    )
+
+    if credentials["mode"] == "chat":
+        entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
+    elif credentials["mode"] == "completion":
+        entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
+    else:
+        raise ValueError(f"Unknown completion mode {credentials['mode']}")
+
+    return entity
diff --git a/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py
new file mode 100644
index 0000000000..7a987c6710
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py
@@ -0,0 +1,10 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class VesslAIProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        pass
diff --git a/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml
new file mode 100644
index 0000000000..6052756cae
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml
@@ -0,0 +1,56 @@
+provider: vessl_ai
+label:
+  en_US: vessl_ai
+icon_small:
+  en_US: icon_s_en.svg
+icon_large:
+  en_US: icon_l_en.png
+background: "#F1EFED"
+help:
+  title:
+    en_US: How to deploy VESSL AI LLM Model Endpoint
+  url:
+    en_US: https://docs.vessl.ai/guides/get-started/llama3-deployment
+supported_model_types:
+  - llm
+configurate_methods:
+  - customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+    placeholder:
+      en_US: Enter your model name
+  credential_form_schemas:
+    - variable: endpoint_url
+      label:
+        en_US: Endpoint URL
+      type: text-input
+      required: true
+      placeholder:
+        en_US: Enter the URL of your endpoint
+    - variable: api_key
+      required: true
+      label:
+        en_US: API Key
+      type: secret-input
+      placeholder:
+        en_US: Enter your VESSL AI secret key
+    - variable: mode
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Completion mode
+      type: select
+      required: false
+      default: chat
+      placeholder:
+        en_US: Select completion mode
+      options:
+        - value: completion
+          label:
+            en_US: Completion
+        - value: chat
+          label:
+            en_US: Chat
diff --git a/api/core/plugin/manager/model.py b/api/core/plugin/manager/model.py
index fb58c4bb8d..2081dcc298 100644
--- a/api/core/plugin/manager/model.py
+++ b/api/core/plugin/manager/model.py
@@ -413,7 +413,7 @@ class PluginModelManager(BasePluginManager):
         """
         response = self._request_with_plugin_daemon_response_stream(
             method="POST",
-            path=f"plugin/{tenant_id}/dispatch/model/voices",
+
path=f"plugin/{tenant_id}/dispatch/tts/model/voices", type=PluginVoicesResponse, data=jsonable_encoder( { @@ -434,8 +434,10 @@ class PluginModelManager(BasePluginManager): ) for resp in response: + voices = [] for voice in resp.voices: - return [{"name": voice.name, "value": voice.value}] + voices.append({"name": voice.name, "value": voice.value}) + return voices return [] diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 3affbd2d0a..57af05861c 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -34,6 +34,8 @@ class RetrievalService: reranking_mode: Optional[str] = "reranking_model", weights: Optional[dict] = None, ): + if not query: + return [] dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: return [] diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 1d4bfef76d..eb78e8aa69 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -3,11 +3,13 @@ import time import uuid from typing import Any +import numpy as np from pydantic import BaseModel, model_validator from pymochow import MochowClient from pymochow.auth.bce_credentials import BceCredentials from pymochow.configuration import Configuration -from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState +from pymochow.exception import ServerError +from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row @@ -116,6 +118,7 @@ class BaiduVector(BaseVector): self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'") def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector] anns = AnnSearch( vector_field=self.field_vector, vector_floats=query_vector, @@ -149,7 +152,13 @@ class BaiduVector(BaseVector): return docs def delete(self) -> None: - self._db.drop_table(table_name=self._collection_name) + try: + self._db.drop_table(table_name=self._collection_name) + except ServerError as e: + if e.code == ServerErrCode.TABLE_NOT_EXIST: + pass + else: + raise def _init_client(self, config) -> MochowClient: config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint) @@ -166,7 +175,14 @@ class BaiduVector(BaseVector): if exists: return self._client.database(self._client_config.database) else: - return self._client.create_database(database_name=self._client_config.database) + try: + self._client.create_database(database_name=self._client_config.database) + except ServerError as e: + if e.code == ServerErrCode.DB_ALREADY_EXIST: + pass + else: + raise + return def _table_existed(self) -> bool: tables = self._db.list_table() @@ -175,7 +191,7 @@ class BaiduVector(BaseVector): def _create_table(self, dimension: int) -> None: # Try to grab distributed lock and create table lock_name = "vector_indexing_lock_{}".format(self._collection_name) - with redis_client.lock(lock_name, timeout=20): + with redis_client.lock(lock_name, timeout=60): table_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(table_exist_cache_key): return 
@@ -238,15 +254,14 @@ class BaiduVector(BaseVector): description="Table for Dify", ) + # Wait for table created + while True: + time.sleep(1) + table = self._db.describe_table(self._collection_name) + if table.state == TableState.NORMAL: + break redis_client.set(table_exist_cache_key, 1, ex=3600) - # Wait for table created - while True: - time.sleep(1) - table = self._db.describe_table(self._collection_name) - if table.state == TableState.NORMAL: - break - class BaiduVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector: diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 0cd2a46460..a6f3ad7fef 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -37,7 +37,7 @@ class TidbService: } spending_limit = { - "monthly": 100, + "monthly": dify_config.TIDB_SPEND_LIMIT, } password = str(uuid.uuid4()).replace("-", "")[:16] display_name = str(uuid.uuid4()).replace("-", "")[:16] diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 40ebf0befd..fc82b2080b 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -27,18 +27,17 @@ class RerankModelRunner(BaseRerankRunner): :return: """ docs = [] - doc_id = [] + doc_id = set() unique_documents = [] - dify_documents = [item for item in documents if item.provider == "dify"] - external_documents = [item for item in documents if item.provider == "external"] - for document in dify_documents: - if document.metadata["doc_id"] not in doc_id: - doc_id.append(document.metadata["doc_id"]) + for document in documents: + if document.provider == "dify" and document.metadata["doc_id"] not in doc_id: + doc_id.add(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) - for document in external_documents: - docs.append(document.page_content) - unique_documents.append(document) + elif document.provider == "external": + if document not in unique_documents: + docs.append(document.page_content) + unique_documents.append(document) documents = unique_documents diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index a6b1ebc159..94a0c783e1 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -116,10 +116,8 @@ class ToolInvokeMessage(BaseModel): class VariableMessage(BaseModel): variable_name: str = Field(..., description="The name of the variable") - variable_value: str = Field(..., - description="The value of the variable") - stream: bool = Field( - default=False, description="Whether the variable is streamed") + variable_value: str = Field(..., description="The value of the variable") + stream: bool = Field(default=False, description="Whether the variable is streamed") @field_validator("variable_value", mode="before") @classmethod @@ -133,8 +131,7 @@ class ToolInvokeMessage(BaseModel): # if stream is true, the value must be a string if values.get("stream"): if not isinstance(value, str): - raise ValueError( - "When 'stream' is True, 'variable_value' must be a string.") + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") return value @@ -271,8 +268,7 @@ class ToolParameter(BaseModel): return str(value) except Exception: - raise ValueError( - f"The tool parameter value {value} is not in 
correct type.") + raise ValueError(f"The tool parameter value {value} is not in correct type of {self.as_normal_type()}.") class ToolParameterForm(Enum): SCHEMA = "schema" # should be set while adding tool @@ -280,17 +276,12 @@ class ToolParameter(BaseModel): LLM = "llm" # will be set by LLM name: str = Field(..., description="The name of the parameter") - label: I18nObject = Field(..., - description="The label presented to the user") - human_description: Optional[I18nObject] = Field( - default=None, description="The description presented to the user") - placeholder: Optional[I18nObject] = Field( - default=None, description="The placeholder presented to the user") - type: ToolParameterType = Field(..., - description="The type of the parameter") + label: I18nObject = Field(..., description="The label presented to the user") + human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") + placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user") + type: ToolParameterType = Field(..., description="The type of the parameter") scope: AppSelectorScope | ModelConfigScope | None = None - form: ToolParameterForm = Field(..., - description="The form of the parameter, schema/form/llm") + form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") llm_description: Optional[str] = None required: Optional[bool] = False default: Optional[Union[float, int, str]] = None @@ -346,8 +337,7 @@ class ToolParameter(BaseModel): class ToolProviderIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") - description: I18nObject = Field(..., - description="The description of the tool") + description: I18nObject = Field(..., description="The description of the tool") icon: str = Field(..., description="The icon of the tool") label: I18nObject = Field(..., description="The label of the tool") tags: Optional[list[ToolLabelEnum]] = Field( @@ -365,8 +355,7 @@ class ToolIdentity(BaseModel): class ToolDescription(BaseModel): - human: I18nObject = Field(..., - description="The description presented to the user") + human: I18nObject = Field(..., description="The description presented to the user") llm: str = Field(..., description="The description presented to the LLM") @@ -375,8 +364,7 @@ class ToolEntity(BaseModel): parameters: list[ToolParameter] = Field(default_factory=list) description: Optional[ToolDescription] = None output_schema: Optional[dict] = None - has_runtime_parameters: bool = Field( - default=False, description="Whether the tool has runtime parameters") + has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters") # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -403,10 +391,8 @@ class WorkflowToolParameterConfiguration(BaseModel): """ name: str = Field(..., description="The name of the parameter") - description: str = Field(..., - description="The description of the parameter") - form: ToolParameter.ToolParameterForm = Field( - ..., description="The form of the parameter") + description: str = Field(..., description="The description of the parameter") + form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter") class ToolInvokeMeta(BaseModel): @@ -414,8 +400,7 @@ class ToolInvokeMeta(BaseModel): Tool invoke meta """ - time_cost: float = Field(..., - description="The time cost of the 
tool invoke") + time_cost: float = Field(..., description="The time cost of the tool invoke") error: Optional[str] = None tool_config: Optional[dict] = None @@ -474,5 +459,4 @@ class ToolProviderID: if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value): raise ValueError("Invalid plugin id") - self.organization, self.plugin_name, self.provider_name = value.split( - "/") + self.organization, self.plugin_name, self.provider_name = value.split("/") diff --git a/api/core/tools/provider/builtin/aliyuque/tools/base.py b/api/core/tools/provider/builtin/aliyuque/tools/base.py deleted file mode 100644 index edfb9fea8e..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/base.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Any - -import requests - - -class AliYuqueTool: - # yuque service url - server_url = "https://www.yuque.com" - - @staticmethod - def auth(token): - session = requests.Session() - session.headers.update({"Accept": "application/json", "X-Auth-Token": token}) - login = session.request("GET", AliYuqueTool.server_url + "/api/v2/user") - login.raise_for_status() - resp = login.json() - return resp - - def request(self, method: str, token, tool_parameters: dict[str, Any], path: str) -> str: - if not token: - raise Exception("token is required") - session = requests.Session() - session.headers.update({"accept": "application/json", "X-Auth-Token": token}) - new_params = {**tool_parameters} - - replacements = {k: v for k, v in new_params.items() if f"{{{k}}}" in path} - - for key, value in replacements.items(): - path = path.replace(f"{{{key}}}", str(value)) - del new_params[key] - - if method.upper() in {"POST", "PUT"}: - session.headers.update( - { - "Content-Type": "application/json", - } - ) - response = session.request(method.upper(), self.server_url + path, json=new_params) - else: - response = session.request(method, self.server_url + path, params=new_params) - response.raise_for_status() - return response.text diff --git a/api/core/tools/provider/builtin/aliyuque/tools/create_document.py b/api/core/tools/provider/builtin/aliyuque/tools/create_document.py deleted file mode 100644 index 01080fd1d5..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/create_document.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Any, Union - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool -from core.tools.tool.builtin_tool import BuiltinTool - - -class AliYuqueCreateDocumentTool(AliYuqueTool, BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - token = self.runtime.credentials.get("token", None) - if not token: - raise Exception("token is required") - return self.create_text_message(self.request("POST", token, tool_parameters, "/api/v2/repos/{book_id}/docs")) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/create_document.yaml b/api/core/tools/provider/builtin/aliyuque/tools/create_document.yaml deleted file mode 100644 index 6ac8ae6696..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/create_document.yaml +++ /dev/null @@ -1,99 +0,0 @@ -identity: - name: aliyuque_create_document - author: 佐井 - label: - en_US: Create Document - zh_Hans: 创建文档 - icon: icon.svg -description: - human: - en_US: Creates a new document within a knowledge base without automatic addition to the table of contents. 
Requires a subsequent call to the "knowledge base directory update API". Supports setting visibility, format, and content. # 接口英文描述 - zh_Hans: 在知识库中创建新文档,但不会自动加入目录,需额外调用“知识库目录更新接口”。允许设置公开性、格式及正文内容。 - llm: Creates docs in a KB. - -parameters: - - name: book_id - type: string - required: true - form: llm - label: - en_US: Knowledge Base ID - zh_Hans: 知识库ID - human_description: - en_US: The unique identifier of the knowledge base where the document will be created. - zh_Hans: 文档将被创建的知识库的唯一标识。 - llm_description: ID of the target knowledge base. - - - name: title - type: string - required: false - form: llm - label: - en_US: Title - zh_Hans: 标题 - human_description: - en_US: The title of the document, defaults to 'Untitled' if not provided. - zh_Hans: 文档标题,默认为'无标题'如未提供。 - llm_description: Title of the document, defaults to 'Untitled'. - - - name: public - type: select - required: false - form: llm - options: - - value: 0 - label: - en_US: Private - zh_Hans: 私密 - - value: 1 - label: - en_US: Public - zh_Hans: 公开 - - value: 2 - label: - en_US: Enterprise-only - zh_Hans: 企业内公开 - label: - en_US: Visibility - zh_Hans: 公开性 - human_description: - en_US: Document visibility (0 Private, 1 Public, 2 Enterprise-only). - zh_Hans: 文档可见性(0 私密, 1 公开, 2 企业内公开)。 - llm_description: Doc visibility options, 0-private, 1-public, 2-enterprise. - - - name: format - type: select - required: false - form: llm - options: - - value: markdown - label: - en_US: markdown - zh_Hans: markdown - - value: html - label: - en_US: html - zh_Hans: html - - value: lake - label: - en_US: lake - zh_Hans: lake - label: - en_US: Content Format - zh_Hans: 内容格式 - human_description: - en_US: Format of the document content (markdown, HTML, Lake). - zh_Hans: 文档内容格式(markdown, HTML, Lake)。 - llm_description: Content format choices, markdown, HTML, Lake. - - - name: body - type: string - required: true - form: llm - label: - en_US: Body Content - zh_Hans: 正文内容 - human_description: - en_US: The actual content of the document. - zh_Hans: 文档的实际内容。 - llm_description: Content of the document. diff --git a/api/core/tools/provider/builtin/aliyuque/tools/delete_document.py b/api/core/tools/provider/builtin/aliyuque/tools/delete_document.py deleted file mode 100644 index 84237cec30..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/delete_document.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Any, Union - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool -from core.tools.tool.builtin_tool import BuiltinTool - - -class AliYuqueDeleteDocumentTool(AliYuqueTool, BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - token = self.runtime.credentials.get("token", None) - if not token: - raise Exception("token is required") - return self.create_text_message( - self.request("DELETE", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}") - ) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/delete_document.yaml b/api/core/tools/provider/builtin/aliyuque/tools/delete_document.yaml deleted file mode 100644 index dddd62d304..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/delete_document.yaml +++ /dev/null @@ -1,37 +0,0 @@ -identity: - name: aliyuque_delete_document - author: 佐井 - label: - en_US: Delete Document - zh_Hans: 删除文档 - icon: icon.svg -description: - human: - en_US: Delete Document - zh_Hans: 根据id删除文档 - llm: Delete document. 
- -parameters: - - name: book_id - type: string - required: true - form: llm - label: - en_US: Knowledge Base ID - zh_Hans: 知识库ID - human_description: - en_US: The unique identifier of the knowledge base where the document will be created. - zh_Hans: 文档将被创建的知识库的唯一标识。 - llm_description: ID of the target knowledge base. - - - name: id - type: string - required: true - form: llm - label: - en_US: Document ID or Path - zh_Hans: 文档 ID or 路径 - human_description: - en_US: Document ID or path. - zh_Hans: 文档 ID or 路径。 - llm_description: Document ID or path. diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.py deleted file mode 100644 index c23d30059a..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Any, Union - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool -from core.tools.tool.builtin_tool import BuiltinTool - - -class AliYuqueDescribeBookIndexPageTool(AliYuqueTool, BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - token = self.runtime.credentials.get("token", None) - if not token: - raise Exception("token is required") - return self.create_text_message( - self.request("GET", token, tool_parameters, "/api/v2/repos/{group_login}/{book_slug}/index_page") - ) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.py deleted file mode 100644 index 36f8c10d6f..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Any, Union - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool -from core.tools.tool.builtin_tool import BuiltinTool - - -class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> (Union)[ToolInvokeMessage, list[ToolInvokeMessage]]: - token = self.runtime.credentials.get("token", None) - if not token: - raise Exception("token is required") - return self.create_text_message(self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/toc")) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.yaml b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.yaml deleted file mode 100644 index 0a481b59eb..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.yaml +++ /dev/null @@ -1,25 +0,0 @@ -identity: - name: aliyuque_describe_book_table_of_contents - author: 佐井 - label: - en_US: Get Book's Table of Contents - zh_Hans: 获取知识库的目录 - icon: icon.svg -description: - human: - en_US: Get Book's Table of Contents. - zh_Hans: 获取知识库的目录。 - llm: Get Book's Table of Contents. - -parameters: - - name: book_id - type: string - required: true - form: llm - label: - en_US: Book ID - zh_Hans: 知识库 ID - human_description: - en_US: Book ID. - zh_Hans: 知识库 ID。 - llm_description: Book ID. 
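The aliyuque tool deletions above and below all route through the `AliYuqueTool.request()` helper from `base.py` (removed earlier in this diff), which fills `{placeholder}` segments of the endpoint path from `tool_parameters` and sends whatever is left over as query parameters, or as a JSON body for POST/PUT. Here is a minimal standalone sketch of that pattern using plain `requests`, with no Dify runtime; the function name and example values are illustrative, not part of the removed code:

```python
from typing import Any

import requests

SERVER_URL = "https://www.yuque.com"


def call_yuque(method: str, token: str, params: dict[str, Any], path: str) -> str:
    """Resolve {placeholder} segments in `path` from `params`, then call the Yuque API."""
    if not token:
        raise ValueError("token is required")
    remaining = dict(params)
    # Parameters referenced by the path template move into the URL itself;
    # everything left over is sent as query params (or a JSON body for writes).
    for key in [k for k in remaining if f"{{{k}}}" in path]:
        path = path.replace(f"{{{key}}}", str(remaining.pop(key)))
    headers = {"Accept": "application/json", "X-Auth-Token": token}
    if method.upper() in {"POST", "PUT"}:
        resp = requests.request(method.upper(), SERVER_URL + path, headers=headers, json=remaining)
    else:
        resp = requests.request(method.upper(), SERVER_URL + path, headers=headers, params=remaining)
    resp.raise_for_status()
    return resp.text


# e.g. call_yuque("GET", token, {"book_id": 42}, "/api/v2/repos/{book_id}/toc")
```

Centralizing the path substitution in one helper is what let each deleted tool class reduce to a single `self.request(...)` call with an endpoint template.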
diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.py deleted file mode 100644 index 4b793cd61f..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.py +++ /dev/null @@ -1,52 +0,0 @@ -import json -from typing import Any, Union -from urllib.parse import urlparse - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool -from core.tools.tool.builtin_tool import BuiltinTool - - -class AliYuqueDescribeDocumentContentTool(AliYuqueTool, BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - new_params = {**tool_parameters} - token = new_params.pop("token") - if not token or token.lower() == "none": - token = self.runtime.credentials.get("token", None) - if not token: - raise Exception("token is required") - new_params = {**tool_parameters} - url = new_params.pop("url") - if not url or not url.startswith("http"): - raise Exception("url is not valid") - - parsed_url = urlparse(url) - path_parts = parsed_url.path.strip("/").split("/") - if len(path_parts) < 3: - raise Exception("url is not correct") - doc_id = path_parts[-1] - book_slug = path_parts[-2] - group_id = path_parts[-3] - - new_params["group_login"] = group_id - new_params["book_slug"] = book_slug - index_page = json.loads( - self.request("GET", token, new_params, "/api/v2/repos/{group_login}/{book_slug}/index_page") - ) - book_id = index_page.get("data", {}).get("book", {}).get("id") - if not book_id: - raise Exception(f"can not parse book_id from {index_page}") - new_params["book_id"] = book_id - new_params["id"] = doc_id - data = self.request("GET", token, new_params, "/api/v2/repos/{book_id}/docs/{id}") - data = json.loads(data) - body_only = tool_parameters.get("body_only") or "" - if body_only.lower() == "true": - return self.create_text_message(data.get("data").get("body")) - else: - raw = data.get("data") - del raw["body_lake"] - del raw["body_html"] - return self.create_text_message(json.dumps(data)) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.py deleted file mode 100644 index 7a45684bed..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Any, Union - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool -from core.tools.tool.builtin_tool import BuiltinTool - - -class AliYuqueDescribeDocumentsTool(AliYuqueTool, BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - token = self.runtime.credentials.get("token", None) - if not token: - raise Exception("token is required") - return self.create_text_message( - self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}") - ) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.yaml b/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.yaml deleted file mode 100644 index 0b14c1afba..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.yaml +++ /dev/null @@ -1,38 +0,0 @@ -identity: - name: aliyuque_describe_documents - 
author: 佐井 - label: - en_US: Get Doc Detail - zh_Hans: 获取文档详情 - icon: icon.svg - -description: - human: - en_US: Retrieves detailed information of a specific document identified by its ID or path within a knowledge base. - zh_Hans: 根据知识库ID和文档ID或路径获取文档详细信息。 - llm: Fetches detailed doc info using ID/path from a knowledge base; supports doc lookup in Yuque. - -parameters: - - name: book_id - type: string - required: true - form: llm - label: - en_US: Knowledge Base ID - zh_Hans: 知识库 ID - human_description: - en_US: Identifier for the knowledge base where the document resides. - zh_Hans: 文档所属知识库的唯一标识。 - llm_description: ID of the knowledge base holding the document. - - - name: id - type: string - required: true - form: llm - label: - en_US: Document ID or Path - zh_Hans: 文档 ID 或路径 - human_description: - en_US: The unique identifier or path of the document to retrieve. - zh_Hans: 需要获取的文档的ID或其在知识库中的路径。 - llm_description: Unique doc ID or its path for retrieval. diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.py b/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.py deleted file mode 100644 index ca0a3909f8..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Any, Union - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool -from core.tools.tool.builtin_tool import BuiltinTool - - -class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> (Union)[ToolInvokeMessage, list[ToolInvokeMessage]]: - token = self.runtime.credentials.get("token", None) - if not token: - raise Exception("token is required") - - doc_ids = tool_parameters.get("doc_ids") - if doc_ids: - doc_ids = [int(doc_id.strip()) for doc_id in doc_ids.split(",")] - tool_parameters["doc_ids"] = doc_ids - - return self.create_text_message(self.request("PUT", token, tool_parameters, "/api/v2/repos/{book_id}/toc")) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.yaml b/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.yaml deleted file mode 100644 index f85970348b..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.yaml +++ /dev/null @@ -1,222 +0,0 @@ -identity: - name: aliyuque_update_book_table_of_contents - author: 佐井 - label: - en_US: Update Book's Table of Contents - zh_Hans: 更新知识库目录 - icon: icon.svg -description: - human: - en_US: Update Book's Table of Contents. - zh_Hans: 更新知识库目录。 - llm: Update Book's Table of Contents. - -parameters: - - name: book_id - type: string - required: true - form: llm - label: - en_US: Book ID - zh_Hans: 知识库 ID - human_description: - en_US: Book ID. - zh_Hans: 知识库 ID。 - llm_description: Book ID. 
- - - name: action - type: select - required: true - form: llm - options: - - value: appendNode - label: - en_US: appendNode - zh_Hans: appendNode - pt_BR: appendNode - - value: prependNode - label: - en_US: prependNode - zh_Hans: prependNode - pt_BR: prependNode - - value: editNode - label: - en_US: editNode - zh_Hans: editNode - pt_BR: editNode - - value: editNode - label: - en_US: removeNode - zh_Hans: removeNode - pt_BR: removeNode - label: - en_US: Action Type - zh_Hans: 操作 - human_description: - en_US: In the operation scenario, sibling node prepending is not supported, deleting a node doesn't remove associated documents, and node deletion has two modes, 'sibling' (delete current node) and 'child' (delete current node and its children). - zh_Hans: 操作,创建场景下不支持同级头插 prependNode,删除节点不会删除关联文档,删除节点时action_mode=sibling (删除当前节点), action_mode=child (删除当前节点及子节点) - llm_description: In the operation scenario, sibling node prepending is not supported, deleting a node doesn't remove associated documents, and node deletion has two modes, 'sibling' (delete current node) and 'child' (delete current node and its children). - - - - name: action_mode - type: select - required: false - form: llm - options: - - value: sibling - label: - en_US: sibling - zh_Hans: 同级 - pt_BR: sibling - - value: child - label: - en_US: child - zh_Hans: 子集 - pt_BR: child - label: - en_US: Action Type - zh_Hans: 操作 - human_description: - en_US: Operation mode (sibling:same level, child:child level). - zh_Hans: 操作模式 (sibling:同级, child:子级)。 - llm_description: Operation mode (sibling:same level, child:child level). - - - name: target_uuid - type: string - required: false - form: llm - label: - en_US: Target node UUID - zh_Hans: 目标节点 UUID - human_description: - en_US: Target node UUID, defaults to root node if left empty. - zh_Hans: 目标节点 UUID, 不填默认为根节点。 - llm_description: Target node UUID, defaults to root node if left empty. - - - name: node_uuid - type: string - required: false - form: llm - label: - en_US: Node UUID - zh_Hans: 操作节点 UUID - human_description: - en_US: Operation node UUID [required for move/update/delete]. - zh_Hans: 操作节点 UUID [移动/更新/删除必填]。 - llm_description: Operation node UUID [required for move/update/delete]. - - - name: doc_ids - type: string - required: false - form: llm - label: - en_US: Document IDs - zh_Hans: 文档id列表 - human_description: - en_US: Document IDs [required for creating documents], separate multiple IDs with ','. - zh_Hans: 文档 IDs [创建文档必填],多个用','分隔。 - llm_description: Document IDs [required for creating documents], separate multiple IDs with ','. - - - - name: type - type: select - required: false - form: llm - default: DOC - options: - - value: DOC - label: - en_US: DOC - zh_Hans: 文档 - pt_BR: DOC - - value: LINK - label: - en_US: LINK - zh_Hans: 链接 - pt_BR: LINK - - value: TITLE - label: - en_US: TITLE - zh_Hans: 分组 - pt_BR: TITLE - label: - en_US: Node type - zh_Hans: 操节点类型 - human_description: - en_US: Node type [required for creation] (DOC:document, LINK:external link, TITLE:group). - zh_Hans: 操节点类型 [创建必填] (DOC:文档, LINK:外链, TITLE:分组)。 - llm_description: Node type [required for creation] (DOC:document, LINK:external link, TITLE:group). - - - name: title - type: string - required: false - form: llm - label: - en_US: Node Name - zh_Hans: 节点名称 - human_description: - en_US: Node name [required for creating groups/external links]. - zh_Hans: 节点名称 [创建分组/外链必填]。 - llm_description: Node name [required for creating groups/external links]. 
- - - name: url - type: string - required: false - form: llm - label: - en_US: Node URL - zh_Hans: 节点URL - human_description: - en_US: Node URL [required for creating external links]. - zh_Hans: 节点 URL [创建外链必填]。 - llm_description: Node URL [required for creating external links]. - - - - name: open_window - type: select - required: false - form: llm - default: 0 - options: - - value: 0 - label: - en_US: DOC - zh_Hans: Current Page - pt_BR: DOC - - value: 1 - label: - en_US: LINK - zh_Hans: New Page - pt_BR: LINK - label: - en_US: Open in new window - zh_Hans: 是否新窗口打开 - human_description: - en_US: Open in new window [optional for external links] (0:open in current page, 1:open in new window). - zh_Hans: 是否新窗口打开 [外链选填] (0:当前页打开, 1:新窗口打开)。 - llm_description: Open in new window [optional for external links] (0:open in current page, 1:open in new window). - - - - name: visible - type: select - required: false - form: llm - default: 1 - options: - - value: 0 - label: - en_US: Invisible - zh_Hans: 隐藏 - pt_BR: Invisible - - value: 1 - label: - en_US: Visible - zh_Hans: 可见 - pt_BR: Visible - label: - en_US: Visibility - zh_Hans: 是否可见 - human_description: - en_US: Visibility (0:invisible, 1:visible). - zh_Hans: 是否可见 (0:不可见, 1:可见)。 - llm_description: Visibility (0:invisible, 1:visible). diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_document.py b/api/core/tools/provider/builtin/aliyuque/tools/update_document.py deleted file mode 100644 index d7eba46ad9..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/update_document.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Any, Union - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool -from core.tools.tool.builtin_tool import BuiltinTool - - -class AliYuqueUpdateDocumentTool(AliYuqueTool, BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - token = self.runtime.credentials.get("token", None) - if not token: - raise Exception("token is required") - return self.create_text_message( - self.request("PUT", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}") - ) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_document.yaml b/api/core/tools/provider/builtin/aliyuque/tools/update_document.yaml deleted file mode 100644 index c2da6b179a..0000000000 --- a/api/core/tools/provider/builtin/aliyuque/tools/update_document.yaml +++ /dev/null @@ -1,87 +0,0 @@ -identity: - name: aliyuque_update_document - author: 佐井 - label: - en_US: Update Document - zh_Hans: 更新文档 - icon: icon.svg -description: - human: - en_US: Update an existing document within a specified knowledge base by providing the document ID or path. - zh_Hans: 通过提供文档ID或路径,更新指定知识库中的现有文档。 - llm: Update doc in a knowledge base via ID/path. -parameters: - - name: book_id - type: string - required: true - form: llm - label: - en_US: Knowledge Base ID - zh_Hans: 知识库 ID - human_description: - en_US: The unique identifier of the knowledge base where the document resides. - zh_Hans: 文档所属知识库的ID。 - llm_description: ID of the knowledge base holding the doc. - - name: id - type: string - required: true - form: llm - label: - en_US: Document ID or Path - zh_Hans: 文档 ID 或 路径 - human_description: - en_US: The unique identifier or the path of the document to be updated. - zh_Hans: 要更新的文档的唯一ID或路径。 - llm_description: Doc's ID or path for update. 
- - - name: title - type: string - required: false - form: llm - label: - en_US: Title - zh_Hans: 标题 - human_description: - en_US: The title of the document, defaults to 'Untitled' if not provided. - zh_Hans: 文档标题,默认为'无标题'如未提供。 - llm_description: Title of the document, defaults to 'Untitled'. - - - name: format - type: select - required: false - form: llm - options: - - value: markdown - label: - en_US: markdown - zh_Hans: markdown - pt_BR: markdown - - value: html - label: - en_US: html - zh_Hans: html - pt_BR: html - - value: lake - label: - en_US: lake - zh_Hans: lake - pt_BR: lake - label: - en_US: Content Format - zh_Hans: 内容格式 - human_description: - en_US: Format of the document content (markdown, HTML, Lake). - zh_Hans: 文档内容格式(markdown, HTML, Lake)。 - llm_description: Content format choices, markdown, HTML, Lake. - - - name: body - type: string - required: true - form: llm - label: - en_US: Body Content - zh_Hans: 正文内容 - human_description: - en_US: The actual content of the document. - zh_Hans: 文档的实际内容。 - llm_description: Content of the document. diff --git a/api/core/tools/provider/builtin/baidu_translate/_assets/icon.png b/api/core/tools/provider/builtin/baidu_translate/_assets/icon.png deleted file mode 100644 index 8eb8f21513..0000000000 Binary files a/api/core/tools/provider/builtin/baidu_translate/_assets/icon.png and /dev/null differ diff --git a/api/core/tools/provider/builtin/baidu_translate/_baidu_translate_tool_base.py b/api/core/tools/provider/builtin/baidu_translate/_baidu_translate_tool_base.py deleted file mode 100644 index ce907c3c61..0000000000 --- a/api/core/tools/provider/builtin/baidu_translate/_baidu_translate_tool_base.py +++ /dev/null @@ -1,11 +0,0 @@ -from hashlib import md5 - - -class BaiduTranslateToolBase: - def _get_sign(self, appid, secret, salt, query): - """ - get baidu translate sign - """ - # concatenate the string in the order of appid+q+salt+secret - str = appid + query + salt + secret - return md5(str.encode("utf-8")).hexdigest() diff --git a/api/core/tools/provider/builtin/baidu_translate/baidu_translate.py b/api/core/tools/provider/builtin/baidu_translate/baidu_translate.py deleted file mode 100644 index cccd2f8c8f..0000000000 --- a/api/core/tools/provider/builtin/baidu_translate/baidu_translate.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Any - -from core.tools.errors import ToolProviderCredentialValidationError -from core.tools.provider.builtin.baidu_translate.tools.translate import BaiduTranslateTool -from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController - - -class BaiduTranslateProvider(BuiltinToolProviderController): - def _validate_credentials(self, credentials: dict[str, Any]) -> None: - try: - BaiduTranslateTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke(user_id="", tool_parameters={"q": "这是一段测试文本", "from": "auto", "to": "en"}) - except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/baidu_translate/baidu_translate.yaml b/api/core/tools/provider/builtin/baidu_translate/baidu_translate.yaml deleted file mode 100644 index 06dadeeefc..0000000000 --- a/api/core/tools/provider/builtin/baidu_translate/baidu_translate.yaml +++ /dev/null @@ -1,39 +0,0 @@ -identity: - author: Xiao Ley - name: baidu_translate - label: - en_US: Baidu Translate - zh_Hans: 百度翻译 - description: - en_US: Translate text using Baidu - zh_Hans: 使用百度进行翻译 - icon: icon.png - tags: - - utilities -credentials_for_provider: - appid: - type: 
secret-input - required: true - label: - en_US: Baidu translate appid - zh_Hans: Baidu translate appid - placeholder: - en_US: Please input your Baidu translate appid - zh_Hans: 请输入你的百度翻译 appid - help: - en_US: Get your Baidu translate appid from Baidu translate - zh_Hans: 从百度翻译开放平台获取你的 appid - url: https://api.fanyi.baidu.com - secret: - type: secret-input - required: true - label: - en_US: Baidu translate secret - zh_Hans: Baidu translate secret - placeholder: - en_US: Please input your Baidu translate secret - zh_Hans: 请输入你的百度翻译 secret - help: - en_US: Get your Baidu translate secret from Baidu translate - zh_Hans: 从百度翻译开放平台获取你的 secret - url: https://api.fanyi.baidu.com diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.py b/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.py deleted file mode 100644 index bce259f31d..0000000000 --- a/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.py +++ /dev/null @@ -1,78 +0,0 @@ -import random -from hashlib import md5 -from typing import Any, Union - -import requests - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.baidu_translate._baidu_translate_tool_base import BaiduTranslateToolBase -from core.tools.tool.builtin_tool import BuiltinTool - - -class BaiduFieldTranslateTool(BuiltinTool, BaiduTranslateToolBase): - def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - """ - invoke tools - """ - BAIDU_FIELD_TRANSLATE_URL = "https://fanyi-api.baidu.com/api/trans/vip/fieldtranslate" - - appid = self.runtime.credentials.get("appid", "") - if not appid: - raise ValueError("invalid baidu translate appid") - - secret = self.runtime.credentials.get("secret", "") - if not secret: - raise ValueError("invalid baidu translate secret") - - q = tool_parameters.get("q", "") - if not q: - raise ValueError("Please input text to translate") - - from_ = tool_parameters.get("from", "") - if not from_: - raise ValueError("Please select source language") - - to = tool_parameters.get("to", "") - if not to: - raise ValueError("Please select destination language") - - domain = tool_parameters.get("domain", "") - if not domain: - raise ValueError("Please select domain") - - salt = str(random.randint(32768, 16777215)) - sign = self._get_sign(appid, secret, salt, q, domain) - - headers = {"Content-Type": "application/x-www-form-urlencoded"} - params = { - "q": q, - "from": from_, - "to": to, - "appid": appid, - "salt": salt, - "domain": domain, - "sign": sign, - "needIntervene": 1, - } - try: - response = requests.post(BAIDU_FIELD_TRANSLATE_URL, headers=headers, data=params) - result = response.json() - - if "trans_result" in result: - result_text = result["trans_result"][0]["dst"] - else: - result_text = f'{result["error_code"]}: {result["error_msg"]}' - - return self.create_text_message(str(result_text)) - except requests.RequestException as e: - raise ValueError(f"Translation service error: {e}") - except Exception: - raise ValueError("Translation service error, please check the network") - - def _get_sign(self, appid, secret, salt, query, domain): - str = appid + query + salt + domain + secret - return md5(str.encode("utf-8")).hexdigest() diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.yaml b/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.yaml deleted file mode 100644 index de51fddbae..0000000000 --- 
a/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.yaml +++ /dev/null @@ -1,123 +0,0 @@ -identity: - name: field_translate - author: Xiao Ley - label: - en_US: Field translate - zh_Hans: 百度领域翻译 -description: - human: - en_US: A tool for Baidu Field translate (Currently, the fields of "novel" and "wiki" only support Chinese to English translation. If the language direction is set to English to Chinese, the default output will be a universal translation result). - zh_Hans: 百度领域翻译,提供多种领域的文本翻译(目前“网络文学领域”和“人文社科领域”仅支持中到英,如设置语言方向为英到中,则默认输出通用翻译结果) - llm: A tool for Baidu Field translate -parameters: - - name: q - type: string - required: true - label: - en_US: Text content - zh_Hans: 文本内容 - human_description: - en_US: Text content to be translated - zh_Hans: 需要翻译的文本内容 - llm_description: Text content to be translated - form: llm - - name: from - type: select - required: true - label: - en_US: source language - zh_Hans: 源语言 - human_description: - en_US: The source language of the input text - zh_Hans: 输入的文本的源语言 - default: auto - form: form - options: - - value: auto - label: - en_US: auto - zh_Hans: 自动检测 - - value: zh - label: - en_US: Chinese - zh_Hans: 中文 - - value: en - label: - en_US: English - zh_Hans: 英语 - - name: to - type: select - required: true - label: - en_US: destination language - zh_Hans: 目标语言 - human_description: - en_US: The destination language of the input text - zh_Hans: 输入文本的目标语言 - default: en - form: form - options: - - value: zh - label: - en_US: Chinese - zh_Hans: 中文 - - value: en - label: - en_US: English - zh_Hans: 英语 - - name: domain - type: select - required: true - label: - en_US: domain - zh_Hans: 领域 - human_description: - en_US: The domain of the input text - zh_Hans: 输入文本的领域 - default: novel - form: form - options: - - value: it - label: - en_US: it - zh_Hans: 信息技术领域 - - value: finance - label: - en_US: finance - zh_Hans: 金融财经领域 - - value: machinery - label: - en_US: machinery - zh_Hans: 机械制造领域 - - value: senimed - label: - en_US: senimed - zh_Hans: 生物医药领域 - - value: novel - label: - en_US: novel (only support Chinese to English translation) - zh_Hans: 网络文学领域(仅支持中到英) - - value: academic - label: - en_US: academic - zh_Hans: 学术论文领域 - - value: aerospace - label: - en_US: aerospace - zh_Hans: 航空航天领域 - - value: wiki - label: - en_US: wiki (only support Chinese to English translation) - zh_Hans: 人文社科领域(仅支持中到英) - - value: news - label: - en_US: news - zh_Hans: 新闻咨询领域 - - value: law - label: - en_US: law - zh_Hans: 法律法规领域 - - value: contract - label: - en_US: contract - zh_Hans: 合同领域 diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/language.py b/api/core/tools/provider/builtin/baidu_translate/tools/language.py deleted file mode 100644 index 3bbaee88b3..0000000000 --- a/api/core/tools/provider/builtin/baidu_translate/tools/language.py +++ /dev/null @@ -1,95 +0,0 @@ -import random -from typing import Any, Union - -import requests - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.baidu_translate._baidu_translate_tool_base import BaiduTranslateToolBase -from core.tools.tool.builtin_tool import BuiltinTool - - -class BaiduLanguageTool(BuiltinTool, BaiduTranslateToolBase): - def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - """ - invoke tools - """ - BAIDU_LANGUAGE_URL = "https://fanyi-api.baidu.com/api/trans/vip/language" - - appid = self.runtime.credentials.get("appid", "") - if not appid: - raise 
ValueError("invalid baidu translate appid") - - secret = self.runtime.credentials.get("secret", "") - if not secret: - raise ValueError("invalid baidu translate secret") - - q = tool_parameters.get("q", "") - if not q: - raise ValueError("Please input text to translate") - - description_language = tool_parameters.get("description_language", "English") - - salt = str(random.randint(32768, 16777215)) - sign = self._get_sign(appid, secret, salt, q) - - headers = {"Content-Type": "application/x-www-form-urlencoded"} - params = { - "q": q, - "appid": appid, - "salt": salt, - "sign": sign, - } - - try: - response = requests.post(BAIDU_LANGUAGE_URL, params=params, headers=headers) - result = response.json() - if "error_code" not in result: - raise ValueError("Translation service error, please check the network") - - result_text = "" - if result["error_code"] != 0: - result_text = f'{result["error_code"]}: {result["error_msg"]}' - else: - result_text = result["data"]["src"] - result_text = self.mapping_result(description_language, result_text) - - return self.create_text_message(result_text) - except requests.RequestException as e: - raise ValueError(f"Translation service error: {e}") - except Exception: - raise ValueError("Translation service error, please check the network") - - def mapping_result(self, description_language: str, result: str) -> str: - """ - mapping result - """ - mapping = { - "English": { - "zh": "Chinese", - "en": "English", - "jp": "Japanese", - "kor": "Korean", - "th": "Thai", - "vie": "Vietnamese", - "ru": "Russian", - }, - "Chinese": { - "zh": "中文", - "en": "英文", - "jp": "日文", - "kor": "韩文", - "th": "泰语", - "vie": "越南语", - "ru": "俄语", - }, - } - - language_mapping = mapping.get(description_language) - if not language_mapping: - return result - - return language_mapping.get(result, result) diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/language.yaml b/api/core/tools/provider/builtin/baidu_translate/tools/language.yaml deleted file mode 100644 index 60cca2e288..0000000000 --- a/api/core/tools/provider/builtin/baidu_translate/tools/language.yaml +++ /dev/null @@ -1,43 +0,0 @@ -identity: - name: language - author: Xiao Ley - label: - en_US: Baidu Language - zh_Hans: 百度语种识别 -description: - human: - en_US: A tool for Baidu Language, support Chinese, English, Japanese, Korean, Thai, Vietnamese and Russian - zh_Hans: 使用百度进行语种识别,支持的语种:中文、英语、日语、韩语、泰语、越南语和俄语 - llm: A tool for Baidu Language -parameters: - - name: q - type: string - required: true - label: - en_US: Text content - zh_Hans: 文本内容 - human_description: - en_US: Text content to be recognized - zh_Hans: 需要识别语言的文本内容 - llm_description: Text content to be recognized - form: llm - - name: description_language - type: select - required: true - label: - en_US: Description language - zh_Hans: 描述语言 - human_description: - en_US: Describe the language used to identify the results - zh_Hans: 描述识别结果所用的语言 - default: Chinese - form: form - options: - - value: Chinese - label: - en_US: Chinese - zh_Hans: 中文 - - value: English - label: - en_US: English - zh_Hans: 英语 diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/translate.py b/api/core/tools/provider/builtin/baidu_translate/tools/translate.py deleted file mode 100644 index 7cd816a3bc..0000000000 --- a/api/core/tools/provider/builtin/baidu_translate/tools/translate.py +++ /dev/null @@ -1,67 +0,0 @@ -import random -from typing import Any, Union - -import requests - -from core.tools.entities.tool_entities import ToolInvokeMessage -from 
core.tools.provider.builtin.baidu_translate._baidu_translate_tool_base import BaiduTranslateToolBase -from core.tools.tool.builtin_tool import BuiltinTool - - -class BaiduTranslateTool(BuiltinTool, BaiduTranslateToolBase): - def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - """ - invoke tools - """ - BAIDU_TRANSLATE_URL = "https://fanyi-api.baidu.com/api/trans/vip/translate" - - appid = self.runtime.credentials.get("appid", "") - if not appid: - raise ValueError("invalid baidu translate appid") - - secret = self.runtime.credentials.get("secret", "") - if not secret: - raise ValueError("invalid baidu translate secret") - - q = tool_parameters.get("q", "") - if not q: - raise ValueError("Please input text to translate") - - from_ = tool_parameters.get("from", "") - if not from_: - raise ValueError("Please select source language") - - to = tool_parameters.get("to", "") - if not to: - raise ValueError("Please select destination language") - - salt = str(random.randint(32768, 16777215)) - sign = self._get_sign(appid, secret, salt, q) - - headers = {"Content-Type": "application/x-www-form-urlencoded"} - params = { - "q": q, - "from": from_, - "to": to, - "appid": appid, - "salt": salt, - "sign": sign, - } - try: - response = requests.post(BAIDU_TRANSLATE_URL, params=params, headers=headers) - result = response.json() - - if "trans_result" in result: - result_text = result["trans_result"][0]["dst"] - else: - result_text = f'{result["error_code"]}: {result["error_msg"]}' - - return self.create_text_message(str(result_text)) - except requests.RequestException as e: - raise ValueError(f"Translation service error: {e}") - except Exception: - raise ValueError("Translation service error, please check the network") diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/translate.yaml b/api/core/tools/provider/builtin/baidu_translate/tools/translate.yaml deleted file mode 100644 index c8ff32cb6b..0000000000 --- a/api/core/tools/provider/builtin/baidu_translate/tools/translate.yaml +++ /dev/null @@ -1,275 +0,0 @@ -identity: - name: translate - author: Xiao Ley - label: - en_US: Translate - zh_Hans: 百度翻译 -description: - human: - en_US: A tool for Baidu Translate - zh_Hans: 百度翻译 - llm: A tool for Baidu Translate -parameters: - - name: q - type: string - required: true - label: - en_US: Text content - zh_Hans: 文本内容 - human_description: - en_US: Text content to be translated - zh_Hans: 需要翻译的文本内容 - llm_description: Text content to be translated - form: llm - - name: from - type: select - required: true - label: - en_US: source language - zh_Hans: 源语言 - human_description: - en_US: The source language of the input text - zh_Hans: 输入的文本的源语言 - default: auto - form: form - options: - - value: auto - label: - en_US: auto - zh_Hans: 自动检测 - - value: zh - label: - en_US: Chinese - zh_Hans: 中文 - - value: en - label: - en_US: English - zh_Hans: 英语 - - value: cht - label: - en_US: Traditional Chinese - zh_Hans: 繁体中文 - - value: yue - label: - en_US: Yue - zh_Hans: 粤语 - - value: wyw - label: - en_US: Wyw - zh_Hans: 文言文 - - value: jp - label: - en_US: Japanese - zh_Hans: 日语 - - value: kor - label: - en_US: Korean - zh_Hans: 韩语 - - value: fra - label: - en_US: French - zh_Hans: 法语 - - value: spa - label: - en_US: Spanish - zh_Hans: 西班牙语 - - value: th - label: - en_US: Thai - zh_Hans: 泰语 - - value: ara - label: - en_US: Arabic - zh_Hans: 阿拉伯语 - - value: ru - label: - en_US: Russian - zh_Hans: 俄语 - - value: pt - label: - en_US: 
Portuguese - zh_Hans: 葡萄牙语 - - value: de - label: - en_US: German - zh_Hans: 德语 - - value: it - label: - en_US: Italian - zh_Hans: 意大利语 - - value: el - label: - en_US: Greek - zh_Hans: 希腊语 - - value: nl - label: - en_US: Dutch - zh_Hans: 荷兰语 - - value: pl - label: - en_US: Polish - zh_Hans: 波兰语 - - value: bul - label: - en_US: Bulgarian - zh_Hans: 保加利亚语 - - value: est - label: - en_US: Estonian - zh_Hans: 爱沙尼亚语 - - value: dan - label: - en_US: Danish - zh_Hans: 丹麦语 - - value: fin - label: - en_US: Finnish - zh_Hans: 芬兰语 - - value: cs - label: - en_US: Czech - zh_Hans: 捷克语 - - value: rom - label: - en_US: Romanian - zh_Hans: 罗马尼亚语 - - value: slo - label: - en_US: Slovak - zh_Hans: 斯洛文尼亚语 - - value: swe - label: - en_US: Swedish - zh_Hans: 瑞典语 - - value: hu - label: - en_US: Hungarian - zh_Hans: 匈牙利语 - - value: vie - label: - en_US: Vietnamese - zh_Hans: 越南语 - - name: to - type: select - required: true - label: - en_US: destination language - zh_Hans: 目标语言 - human_description: - en_US: The destination language of the input text - zh_Hans: 输入文本的目标语言 - default: en - form: form - options: - - value: zh - label: - en_US: Chinese - zh_Hans: 中文 - - value: en - label: - en_US: English - zh_Hans: 英语 - - value: cht - label: - en_US: Traditional Chinese - zh_Hans: 繁体中文 - - value: yue - label: - en_US: Yue - zh_Hans: 粤语 - - value: wyw - label: - en_US: Wyw - zh_Hans: 文言文 - - value: jp - label: - en_US: Japanese - zh_Hans: 日语 - - value: kor - label: - en_US: Korean - zh_Hans: 韩语 - - value: fra - label: - en_US: French - zh_Hans: 法语 - - value: spa - label: - en_US: Spanish - zh_Hans: 西班牙语 - - value: th - label: - en_US: Thai - zh_Hans: 泰语 - - value: ara - label: - en_US: Arabic - zh_Hans: 阿拉伯语 - - value: ru - label: - en_US: Russian - zh_Hans: 俄语 - - value: pt - label: - en_US: Portuguese - zh_Hans: 葡萄牙语 - - value: de - label: - en_US: German - zh_Hans: 德语 - - value: it - label: - en_US: Italian - zh_Hans: 意大利语 - - value: el - label: - en_US: Greek - zh_Hans: 希腊语 - - value: nl - label: - en_US: Dutch - zh_Hans: 荷兰语 - - value: pl - label: - en_US: Polish - zh_Hans: 波兰语 - - value: bul - label: - en_US: Bulgarian - zh_Hans: 保加利亚语 - - value: est - label: - en_US: Estonian - zh_Hans: 爱沙尼亚语 - - value: dan - label: - en_US: Danish - zh_Hans: 丹麦语 - - value: fin - label: - en_US: Finnish - zh_Hans: 芬兰语 - - value: cs - label: - en_US: Czech - zh_Hans: 捷克语 - - value: rom - label: - en_US: Romanian - zh_Hans: 罗马尼亚语 - - value: slo - label: - en_US: Slovak - zh_Hans: 斯洛文尼亚语 - - value: swe - label: - en_US: Swedish - zh_Hans: 瑞典语 - - value: hu - label: - en_US: Hungarian - zh_Hans: 匈牙利语 - - value: vie - label: - en_US: Vietnamese - zh_Hans: 越南语 diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py deleted file mode 100644 index 209d6ecba4..0000000000 --- a/api/core/tools/provider/builtin/chart/chart.py +++ /dev/null @@ -1,36 +0,0 @@ -import matplotlib.pyplot as plt -from matplotlib.font_manager import FontProperties - -from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController - - -def set_chinese_font(): - font_list = [ - "PingFang SC", - "SimHei", - "Microsoft YaHei", - "STSong", - "SimSun", - "Arial Unicode MS", - "Noto Sans CJK SC", - "Noto Sans CJK JP", - ] - - for font in font_list: - chinese_font = FontProperties(font) - if chinese_font.get_name() == font: - return chinese_font - - return FontProperties() - - -# use a business theme -plt.style.use("seaborn-v0_8-darkgrid") -plt.rcParams["axes.unicode_minus"] = False -font_properties = 
set_chinese_font() -plt.rcParams["font.family"] = font_properties.get_name() - - -class ChartProvider(BuiltinToolProviderController): - def _validate_credentials(self, credentials: dict) -> None: - pass diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py deleted file mode 100644 index d4bf713441..0000000000 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py +++ /dev/null @@ -1,114 +0,0 @@ -import base64 -import io -import json -import random -import uuid - -import httpx -from websocket import WebSocket -from yarl import URL - -from core.file.file_manager import _get_encoded_string -from core.file.models import File - - -class ComfyUiClient: - def __init__(self, base_url: str): - self.base_url = URL(base_url) - - def get_history(self, prompt_id: str) -> dict: - res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id}) - history = res.json()[prompt_id] - return history - - def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes: - response = httpx.get( - str(self.base_url / "view"), - params={"filename": filename, "subfolder": subfolder, "type": folder_type}, - ) - return response.content - - def upload_image(self, image_file: File) -> dict: - image_content = base64.b64decode(_get_encoded_string(image_file)) - file = io.BytesIO(image_content) - files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"} - res = httpx.post(str(self.base_url / "upload/image"), files=files) - return res.json() - - def queue_prompt(self, client_id: str, prompt: dict) -> str: - res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt}) - prompt_id = res.json()["prompt_id"] - return prompt_id - - def open_websocket_connection(self) -> tuple[WebSocket, str]: - client_id = str(uuid.uuid4()) - ws = WebSocket() - ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}" - ws.connect(ws_address) - return ws, client_id - - def set_prompt( - self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = "" - ) -> dict: - """ - find the first KSampler, then can find the prompt node through it. 
- """ - prompt = origin_prompt.copy() - id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} - k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0] - prompt.get(k_sampler)["inputs"]["seed"] = random.randint(10**14, 10**15 - 1) - positive_input_id = prompt.get(k_sampler)["inputs"]["positive"][0] - prompt.get(positive_input_id)["inputs"]["text"] = positive_prompt - - if negative_prompt != "": - negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0] - prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt - - if image_name != "": - image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0] - prompt.get(image_loader)["inputs"]["image"] = image_name - return prompt - - def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str): - node_ids = list(prompt.keys()) - finished_nodes = [] - - while True: - out = ws.recv() - if isinstance(out, str): - message = json.loads(out) - if message["type"] == "progress": - data = message["data"] - current_step = data["value"] - print("In K-Sampler -> Step: ", current_step, " of: ", data["max"]) - if message["type"] == "execution_cached": - data = message["data"] - for itm in data["nodes"]: - if itm not in finished_nodes: - finished_nodes.append(itm) - print("Progress: ", len(finished_nodes), "/", len(node_ids), " Tasks done") - if message["type"] == "executing": - data = message["data"] - if data["node"] not in finished_nodes: - finished_nodes.append(data["node"]) - print("Progress: ", len(finished_nodes), "/", len(node_ids), " Tasks done") - - if data["node"] is None and data["prompt_id"] == prompt_id: - break # Execution is done - else: - continue - - def generate_image_by_prompt(self, prompt: dict) -> list[bytes]: - try: - ws, client_id = self.open_websocket_connection() - prompt_id = self.queue_prompt(client_id, prompt) - self.track_progress(prompt, ws, prompt_id) - history = self.get_history(prompt_id) - images = [] - for output in history["outputs"].values(): - for img in output.get("images", []): - image_data = self.get_image(img["filename"], img["subfolder"], img["type"]) - images.append(image_data) - return images - finally: - ws.close() diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py deleted file mode 100644 index 11320d5d0f..0000000000 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py +++ /dev/null @@ -1,34 +0,0 @@ -import json -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient -from core.tools.tool.builtin_tool import BuiltinTool - - -class ComfyUIWorkflowTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: - comfyui = ComfyUiClient(self.runtime.credentials["base_url"]) - - positive_prompt = tool_parameters.get("positive_prompt") - negative_prompt = tool_parameters.get("negative_prompt") - workflow = tool_parameters.get("workflow_json") - image_name = "" - if image := tool_parameters.get("image"): - image_name = comfyui.upload_image(image).get("name") - - try: - origin_prompt = json.loads(workflow) - except: - return self.create_text_message("the Workflow JSON is not correct") - - prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt, image_name) - images = 
comfyui.generate_image_by_prompt(prompt) - result = [] - for img in images: - result.append( - self.create_blob_message( - blob=img, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value - ) - ) - return result diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml deleted file mode 100644 index 55fcdad825..0000000000 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml +++ /dev/null @@ -1,42 +0,0 @@ -identity: - name: workflow - author: hjlarry - label: - en_US: workflow - zh_Hans: 工作流 -description: - human: - en_US: Run ComfyUI workflow. - zh_Hans: 运行ComfyUI工作流。 - llm: Run ComfyUI workflow. -parameters: - - name: positive_prompt - type: string - label: - en_US: Prompt - zh_Hans: 提示词 - llm_description: Image prompt, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English. - form: llm - - name: negative_prompt - type: string - label: - en_US: Negative Prompt - zh_Hans: 负面提示词 - llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English. - form: llm - - name: image - type: file - label: - en_US: Input Image - zh_Hans: 输入的图片 - llm_description: The input image, used to transfer to the comfyui workflow to generate another image. - form: llm - - name: workflow_json - type: string - required: true - label: - en_US: Workflow JSON - human_description: - en_US: exported from ComfyUI workflow - zh_Hans: 从ComfyUI的工作流中导出 - form: form diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py deleted file mode 100644 index 54bb38755a..0000000000 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Any - -from duckduckgo_search import DDGS - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class DuckDuckGoImageSearchTool(BuiltinTool): - """ - Tool for performing an image search using DuckDuckGo search engine. 
- """ - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: - query_dict = { - "keywords": tool_parameters.get("query"), - "timelimit": tool_parameters.get("timelimit"), - "size": tool_parameters.get("size"), - "max_results": tool_parameters.get("max_results"), - } - response = DDGS().images(**query_dict) - markdown_result = "\n\n" - json_result = [] - for res in response: - markdown_result += f"![{res.get('title') or ''}]({res.get('image') or ''})" - json_result.append(self.create_json_message(res)) - return [self.create_text_message(markdown_result)] + json_result diff --git a/api/core/tools/provider/builtin/feishu_base/_assets/icon.png b/api/core/tools/provider/builtin/feishu_base/_assets/icon.png deleted file mode 100644 index 787427e721..0000000000 Binary files a/api/core/tools/provider/builtin/feishu_base/_assets/icon.png and /dev/null differ diff --git a/api/core/tools/provider/builtin/feishu_base/feishu_base.py b/api/core/tools/provider/builtin/feishu_base/feishu_base.py deleted file mode 100644 index f301ec5355..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/feishu_base.py +++ /dev/null @@ -1,7 +0,0 @@ -from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController -from core.tools.utils.feishu_api_utils import auth - - -class FeishuBaseProvider(BuiltinToolProviderController): - def _validate_credentials(self, credentials: dict) -> None: - auth(credentials) diff --git a/api/core/tools/provider/builtin/feishu_base/feishu_base.yaml b/api/core/tools/provider/builtin/feishu_base/feishu_base.yaml deleted file mode 100644 index 456dd8c88f..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/feishu_base.yaml +++ /dev/null @@ -1,36 +0,0 @@ -identity: - author: Doug Lea - name: feishu_base - label: - en_US: Feishu Base - zh_Hans: 飞书多维表格 - description: - en_US: | - Feishu base, requires the following permissions: bitable:app. 
- zh_Hans: | - 飞书多维表格,需要开通以下权限: bitable:app。 - icon: icon.png - tags: - - social - - productivity -credentials_for_provider: - app_id: - type: text-input - required: true - label: - en_US: APP ID - placeholder: - en_US: Please input your feishu app id - zh_Hans: 请输入你的飞书 app id - help: - en_US: Get your app_id and app_secret from Feishu - zh_Hans: 从飞书获取您的 app_id 和 app_secret - url: https://open.larkoffice.com/app - app_secret: - type: secret-input - required: true - label: - en_US: APP Secret - placeholder: - en_US: Please input your app secret - zh_Hans: 请输入你的飞书 app secret diff --git a/api/core/tools/provider/builtin/feishu_base/tools/add_records.py b/api/core/tools/provider/builtin/feishu_base/tools/add_records.py deleted file mode 100644 index 905f8b7880..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/add_records.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool -from core.tools.utils.feishu_api_utils import FeishuRequest - - -class AddRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get("app_id") - app_secret = self.runtime.credentials.get("app_secret") - client = FeishuRequest(app_id, app_secret) - - app_token = tool_parameters.get("app_token") - table_id = tool_parameters.get("table_id") - table_name = tool_parameters.get("table_name") - records = tool_parameters.get("records") - user_id_type = tool_parameters.get("user_id_type", "open_id") - - res = client.add_records(app_token, table_id, table_name, records, user_id_type) - return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/add_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/add_records.yaml deleted file mode 100644 index f2a93490dc..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/add_records.yaml +++ /dev/null @@ -1,91 +0,0 @@ -identity: - name: add_records - author: Doug Lea - label: - en_US: Add Records - zh_Hans: 新增多条记录 -description: - human: - en_US: Add Multiple Records to Multidimensional Table - zh_Hans: 在多维表格数据表中新增多条记录 - llm: A tool for adding multiple records to a multidimensional table. (在多维表格数据表中新增多条记录) -parameters: - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: app_token - human_description: - en_US: Unique identifier for the multidimensional table, supports inputting document URL. - zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 - llm_description: 多维表格的唯一标识符,支持输入文档 URL。 - form: llm - - - name: table_id - type: string - required: false - label: - en_US: table_id - zh_Hans: table_id - human_description: - en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. - zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 - llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 - form: llm - - - name: table_name - type: string - required: false - label: - en_US: table_name - zh_Hans: table_name - human_description: - en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. 
- zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 - llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 - form: llm - - - name: records - type: string - required: true - label: - en_US: records - zh_Hans: 记录列表 - human_description: - en_US: | - List of records to be added in this request. Example value: [{"multi-line-text":"text content","single_select":"option 1","date":1674206443000}] - For supported field types, refer to the integration guide (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification). For data structures of different field types, refer to the data structure overview (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure). - zh_Hans: | - 本次请求将要新增的记录列表,示例值:[{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000}]。 - 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 - llm_description: | - 本次请求将要新增的记录列表,示例值:[{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000}]。 - 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 - form: llm - - - name: user_id_type - type: select - required: false - options: - - value: open_id - label: - en_US: open_id - zh_Hans: open_id - - value: union_id - label: - en_US: union_id - zh_Hans: union_id - - value: user_id - label: - en_US: user_id - zh_Hans: user_id - default: "open_id" - label: - en_US: user_id_type - zh_Hans: 用户 ID 类型 - human_description: - en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. - zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 - llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 - form: form diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base.py b/api/core/tools/provider/builtin/feishu_base/tools/create_base.py deleted file mode 100644 index f074acc5ff..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool -from core.tools.utils.feishu_api_utils import FeishuRequest - - -class CreateBaseTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get("app_id") - app_secret = self.runtime.credentials.get("app_secret") - client = FeishuRequest(app_id, app_secret) - - name = tool_parameters.get("name") - folder_token = tool_parameters.get("folder_token") - - res = client.create_base(name, folder_token) - return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base.yaml b/api/core/tools/provider/builtin/feishu_base/tools/create_base.yaml deleted file mode 100644 index 3ec91a90e7..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base.yaml +++ /dev/null @@ -1,42 +0,0 @@ -identity: - name: create_base - author: Doug Lea - label: - en_US: Create Base - zh_Hans: 创建多维表格 -description: - human: - en_US: Create Multidimensional Table in Specified Directory - zh_Hans: 在指定目录下创建多维表格 - llm: A tool for creating a multidimensional table in a specified directory. 
(在指定目录下创建多维表格) -parameters: - - name: name - type: string - required: false - label: - en_US: name - zh_Hans: 多维表格 App 名字 - human_description: - en_US: | - Name of the multidimensional table App. Example value: "A new multidimensional table". - zh_Hans: 多维表格 App 名字,示例值:"一篇新的多维表格"。 - llm_description: 多维表格 App 名字,示例值:"一篇新的多维表格"。 - form: llm - - - name: folder_token - type: string - required: false - label: - en_US: folder_token - zh_Hans: 多维表格 App 归属文件夹 - human_description: - en_US: | - Folder where the multidimensional table App belongs. Default is empty, meaning the table will be created in the root directory of the cloud space. Example values: Fa3sfoAgDlMZCcdcJy1cDFg8nJc or https://svi136aogf123.feishu.cn/drive/folder/Fa3sfoAgDlMZCcdcJy1cDFg8nJc. - The folder_token must be an existing folder and supports inputting folder token or folder URL. - zh_Hans: | - 多维表格 App 归属文件夹。默认为空,表示多维表格将被创建在云空间根目录。示例值: Fa3sfoAgDlMZCcdcJy1cDFg8nJc 或者 https://svi136aogf123.feishu.cn/drive/folder/Fa3sfoAgDlMZCcdcJy1cDFg8nJc。 - folder_token 必须是已存在的文件夹,支持输入文件夹 token 或者文件夹 URL。 - llm_description: | - 多维表格 App 归属文件夹。默认为空,表示多维表格将被创建在云空间根目录。示例值: Fa3sfoAgDlMZCcdcJy1cDFg8nJc 或者 https://svi136aogf123.feishu.cn/drive/folder/Fa3sfoAgDlMZCcdcJy1cDFg8nJc。 - folder_token 必须是已存在的文件夹,支持输入文件夹 token 或者文件夹 URL。 - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_table.py b/api/core/tools/provider/builtin/feishu_base/tools/create_table.py deleted file mode 100644 index 81f2617545..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_table.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool -from core.tools.utils.feishu_api_utils import FeishuRequest - - -class CreateTableTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get("app_id") - app_secret = self.runtime.credentials.get("app_secret") - client = FeishuRequest(app_id, app_secret) - - app_token = tool_parameters.get("app_token") - table_name = tool_parameters.get("table_name") - default_view_name = tool_parameters.get("default_view_name") - fields = tool_parameters.get("fields") - - res = client.create_table(app_token, table_name, default_view_name, fields) - return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_table.yaml b/api/core/tools/provider/builtin/feishu_base/tools/create_table.yaml deleted file mode 100644 index 8b1007b9a5..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_table.yaml +++ /dev/null @@ -1,61 +0,0 @@ -identity: - name: create_table - author: Doug Lea - label: - en_US: Create Table - zh_Hans: 新增数据表 -description: - human: - en_US: Add a Data Table to Multidimensional Table - zh_Hans: 在多维表格中新增一个数据表 - llm: A tool for adding a data table to a multidimensional table. (在多维表格中新增一个数据表) -parameters: - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: app_token - human_description: - en_US: Unique identifier for the multidimensional table, supports inputting document URL. - zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 - llm_description: 多维表格的唯一标识符,支持输入文档 URL。 - form: llm - - - name: table_name - type: string - required: true - label: - en_US: Table Name - zh_Hans: 数据表名称 - human_description: - en_US: | - The name of the data table, length range: 1 character to 100 characters. 
- zh_Hans: 数据表名称,长度范围:1 字符 ~ 100 字符。 - llm_description: 数据表名称,长度范围:1 字符 ~ 100 字符。 - form: llm - - - name: default_view_name - type: string - required: false - label: - en_US: Default View Name - zh_Hans: 默认表格视图的名称 - human_description: - en_US: The name of the default table view, defaults to "Table" if not filled. - zh_Hans: 默认表格视图的名称,不填则默认为"表格"。 - llm_description: 默认表格视图的名称,不填则默认为"表格"。 - form: llm - - - name: fields - type: string - required: true - label: - en_US: Initial Fields - zh_Hans: 初始字段 - human_description: - en_US: | - Initial fields of the data table, format: [ { "field_name": "Multi-line Text","type": 1 },{ "field_name": "Number","type": 2 },{ "field_name": "Single Select","type": 3 },{ "field_name": "Multiple Select","type": 4 },{ "field_name": "Date","type": 5 } ]. For field details, refer to: https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-field/guide - zh_Hans: 数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。字段详情参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-field/guide - llm_description: 数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。字段详情参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-field/guide - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_records.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_records.py deleted file mode 100644 index c896a2c81b..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_records.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool -from core.tools.utils.feishu_api_utils import FeishuRequest - - -class DeleteRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get("app_id") - app_secret = self.runtime.credentials.get("app_secret") - client = FeishuRequest(app_id, app_secret) - - app_token = tool_parameters.get("app_token") - table_id = tool_parameters.get("table_id") - table_name = tool_parameters.get("table_name") - record_ids = tool_parameters.get("record_ids") - - res = client.delete_records(app_token, table_id, table_name, record_ids) - return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/delete_records.yaml deleted file mode 100644 index c30ebd630c..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_records.yaml +++ /dev/null @@ -1,86 +0,0 @@ -identity: - name: delete_records - author: Doug Lea - label: - en_US: Delete Records - zh_Hans: 删除多条记录 -description: - human: - en_US: Delete Multiple Records from Multidimensional Table - zh_Hans: 删除多维表格数据表中的多条记录 - llm: A tool for deleting multiple records from a multidimensional table. (删除多维表格数据表中的多条记录) -parameters: - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: app_token - human_description: - en_US: Unique identifier for the multidimensional table, supports inputting document URL. 
- zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 - llm_description: 多维表格的唯一标识符,支持输入文档 URL。 - form: llm - - - name: table_id - type: string - required: false - label: - en_US: table_id - zh_Hans: table_id - human_description: - en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. - zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 - llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 - form: llm - - - name: table_name - type: string - required: false - label: - en_US: table_name - zh_Hans: table_name - human_description: - en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. - zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 - llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 - form: llm - - - name: record_ids - type: string - required: true - label: - en_US: Record IDs - zh_Hans: 记录 ID 列表 - human_description: - en_US: | - List of IDs for the records to be deleted, example value: ["recwNXzPQv"]. - zh_Hans: 删除的多条记录 ID 列表,示例值:["recwNXzPQv"]。 - llm_description: 删除的多条记录 ID 列表,示例值:["recwNXzPQv"]。 - form: llm - - - name: user_id_type - type: select - required: false - options: - - value: open_id - label: - en_US: open_id - zh_Hans: open_id - - value: union_id - label: - en_US: union_id - zh_Hans: union_id - - value: user_id - label: - en_US: user_id - zh_Hans: user_id - default: "open_id" - label: - en_US: user_id_type - zh_Hans: 用户 ID 类型 - human_description: - en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. - zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 - llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 - form: form diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.py deleted file mode 100644 index f732a16da6..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool -from core.tools.utils.feishu_api_utils import FeishuRequest - - -class DeleteTablesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get("app_id") - app_secret = self.runtime.credentials.get("app_secret") - client = FeishuRequest(app_id, app_secret) - - app_token = tool_parameters.get("app_token") - table_ids = tool_parameters.get("table_ids") - table_names = tool_parameters.get("table_names") - - res = client.delete_tables(app_token, table_ids, table_names) - return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.yaml b/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.yaml deleted file mode 100644 index 498126eae5..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.yaml +++ /dev/null @@ -1,49 +0,0 @@ -identity: - name: delete_tables - author: Doug Lea - label: - en_US: Delete Tables - zh_Hans: 删除数据表 -description: - human: - en_US: Batch Delete Data Tables from Multidimensional Table - zh_Hans: 批量删除多维表格中的数据表 - llm: A tool for batch deleting data tables from a multidimensional table. 
(批量删除多维表格中的数据表) -parameters: - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: app_token - human_description: - en_US: Unique identifier for the multidimensional table, supports inputting document URL. - zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 - llm_description: 多维表格的唯一标识符,支持输入文档 URL。 - form: llm - - - name: table_ids - type: string - required: false - label: - en_US: Table IDs - zh_Hans: 数据表 ID - human_description: - en_US: | - IDs of the tables to be deleted. Each operation supports deleting up to 50 tables. Example: ["tbl1TkhyTWDkSoZ3"]. Ensure that either table_ids or table_names is not empty. - zh_Hans: 待删除的数据表的 ID,每次操作最多支持删除 50 个数据表。示例值:["tbl1TkhyTWDkSoZ3"]。请确保 table_ids 和 table_names 至少有一个不为空。 - llm_description: 待删除的数据表的 ID,每次操作最多支持删除 50 个数据表。示例值:["tbl1TkhyTWDkSoZ3"]。请确保 table_ids 和 table_names 至少有一个不为空。 - form: llm - - - name: table_names - type: string - required: false - label: - en_US: Table Names - zh_Hans: 数据表名称 - human_description: - en_US: | - Names of the tables to be deleted. Each operation supports deleting up to 50 tables. Example: ["Table1", "Table2"]. Ensure that either table_names or table_ids is not empty. - zh_Hans: 待删除的数据表的名称,每次操作最多支持删除 50 个数据表。示例值:["数据表1", "数据表2"]。请确保 table_names 和 table_ids 至少有一个不为空。 - llm_description: 待删除的数据表的名称,每次操作最多支持删除 50 个数据表。示例值:["数据表1", "数据表2"]。请确保 table_names 和 table_ids 至少有一个不为空。 - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py b/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py deleted file mode 100644 index a74e9be288..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool -from core.tools.utils.feishu_api_utils import FeishuRequest - - -class GetBaseInfoTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get("app_id") - app_secret = self.runtime.credentials.get("app_secret") - client = FeishuRequest(app_id, app_secret) - - app_token = tool_parameters.get("app_token") - - res = client.get_base_info(app_token) - return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.yaml b/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.yaml deleted file mode 100644 index eb0e7a26c0..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.yaml +++ /dev/null @@ -1,23 +0,0 @@ -identity: - name: get_base_info - author: Doug Lea - label: - en_US: Get Base Info - zh_Hans: 获取多维表格元数据 -description: - human: - en_US: Get Metadata Information of Specified Multidimensional Table - zh_Hans: 获取指定多维表格的元数据信息 - llm: A tool for getting metadata information of a specified multidimensional table. (获取指定多维表格的元数据信息) -parameters: - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: app_token - human_description: - en_US: Unique identifier for the multidimensional table, supports inputting document URL. 
- zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 - llm_description: 多维表格的唯一标识符,支持输入文档 URL。 - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/list_tables.py deleted file mode 100644 index c7768a496d..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_tables.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool -from core.tools.utils.feishu_api_utils import FeishuRequest - - -class ListTablesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get("app_id") - app_secret = self.runtime.credentials.get("app_secret") - client = FeishuRequest(app_id, app_secret) - - app_token = tool_parameters.get("app_token") - page_token = tool_parameters.get("page_token") - page_size = tool_parameters.get("page_size", 20) - - res = client.list_tables(app_token, page_token, page_size) - return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_tables.yaml b/api/core/tools/provider/builtin/feishu_base/tools/list_tables.yaml deleted file mode 100644 index 5a3891bd45..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_tables.yaml +++ /dev/null @@ -1,50 +0,0 @@ -identity: - name: list_tables - author: Doug Lea - label: - en_US: List Tables - zh_Hans: 列出数据表 -description: - human: - en_US: Get All Data Tables under Multidimensional Table - zh_Hans: 获取多维表格下的所有数据表 - llm: A tool for getting all data tables under a multidimensional table. (获取多维表格下的所有数据表) -parameters: - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: app_token - human_description: - en_US: Unique identifier for the multidimensional table, supports inputting document URL. - zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 - llm_description: 多维表格的唯一标识符,支持输入文档 URL。 - form: llm - - - name: page_size - type: number - required: false - default: 20 - label: - en_US: page_size - zh_Hans: 分页大小 - human_description: - en_US: | - Page size, default value: 20, maximum value: 100. - zh_Hans: 分页大小,默认值:20,最大值:100。 - llm_description: 分页大小,默认值:20,最大值:100。 - form: llm - - - name: page_token - type: string - required: false - label: - en_US: page_token - zh_Hans: 分页标记 - human_description: - en_US: | - Page token, leave empty for the first request to start from the beginning; a new page_token will be returned if there are more items in the paginated query results, which can be used for the next traversal. Example value: "tblsRc9GRRXKqhvW". 
- zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 - llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_records.py b/api/core/tools/provider/builtin/feishu_base/tools/read_records.py deleted file mode 100644 index 46f3df4ff0..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/read_records.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool -from core.tools.utils.feishu_api_utils import FeishuRequest - - -class ReadRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get("app_id") - app_secret = self.runtime.credentials.get("app_secret") - client = FeishuRequest(app_id, app_secret) - - app_token = tool_parameters.get("app_token") - table_id = tool_parameters.get("table_id") - table_name = tool_parameters.get("table_name") - record_ids = tool_parameters.get("record_ids") - user_id_type = tool_parameters.get("user_id_type", "open_id") - - res = client.read_records(app_token, table_id, table_name, record_ids, user_id_type) - return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/read_records.yaml deleted file mode 100644 index 911e667cfc..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/read_records.yaml +++ /dev/null @@ -1,86 +0,0 @@ -identity: - name: read_records - author: Doug Lea - label: - en_US: Read Records - zh_Hans: 批量获取记录 -description: - human: - en_US: Batch Retrieve Records from Multidimensional Table - zh_Hans: 批量获取多维表格数据表中的记录信息 - llm: A tool for batch retrieving records from a multidimensional table, supporting up to 100 records per call. (批量获取多维表格数据表中的记录信息,单次调用最多支持查询 100 条记录) - -parameters: - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: app_token - human_description: - en_US: Unique identifier for the multidimensional table, supports inputting document URL. - zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 - llm_description: 多维表格的唯一标识符,支持输入文档 URL。 - form: llm - - - name: table_id - type: string - required: false - label: - en_US: table_id - zh_Hans: table_id - human_description: - en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. - zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 - llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 - form: llm - - - name: table_name - type: string - required: false - label: - en_US: table_name - zh_Hans: table_name - human_description: - en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. - zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 - llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 - form: llm - - - name: record_ids - type: string - required: true - label: - en_US: record_ids - zh_Hans: 记录 ID 列表 - human_description: - en_US: List of record IDs, which can be obtained by calling the "Query Records API". 
- zh_Hans: 记录 ID 列表,可以通过调用"查询记录接口"获取。 - llm_description: 记录 ID 列表,可以通过调用"查询记录接口"获取。 - form: llm - - - name: user_id_type - type: select - required: false - options: - - value: open_id - label: - en_US: open_id - zh_Hans: open_id - - value: union_id - label: - en_US: union_id - zh_Hans: union_id - - value: user_id - label: - en_US: user_id - zh_Hans: user_id - default: "open_id" - label: - en_US: user_id_type - zh_Hans: 用户 ID 类型 - human_description: - en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. - zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 - llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 - form: form diff --git a/api/core/tools/provider/builtin/feishu_base/tools/search_records.py b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py deleted file mode 100644 index c959496735..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/search_records.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool -from core.tools.utils.feishu_api_utils import FeishuRequest - - -class SearchRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get("app_id") - app_secret = self.runtime.credentials.get("app_secret") - client = FeishuRequest(app_id, app_secret) - - app_token = tool_parameters.get("app_token") - table_id = tool_parameters.get("table_id") - table_name = tool_parameters.get("table_name") - view_id = tool_parameters.get("view_id") - field_names = tool_parameters.get("field_names") - sort = tool_parameters.get("sort") - filters = tool_parameters.get("filter") - page_token = tool_parameters.get("page_token") - automatic_fields = tool_parameters.get("automatic_fields", False) - user_id_type = tool_parameters.get("user_id_type", "open_id") - page_size = tool_parameters.get("page_size", 20) - - res = client.search_record( - app_token, - table_id, - table_name, - view_id, - field_names, - sort, - filters, - page_token, - automatic_fields, - user_id_type, - page_size, - ) - return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/search_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/search_records.yaml deleted file mode 100644 index 6cac4b0524..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/search_records.yaml +++ /dev/null @@ -1,163 +0,0 @@ -identity: - name: search_records - author: Doug Lea - label: - en_US: Search Records - zh_Hans: 查询记录 -description: - human: - en_US: Query records in a multidimensional table, up to 500 rows per query. - zh_Hans: 查询多维表格数据表中的记录,单次最多查询 500 行记录。 - llm: A tool for querying records in a multidimensional table, up to 500 rows per query. (查询多维表格数据表中的记录,单次最多查询 500 行记录) -parameters: - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: app_token - human_description: - en_US: Unique identifier for the multidimensional table, supports inputting document URL. - zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 - llm_description: 多维表格的唯一标识符,支持输入文档 URL。 - form: llm - - - name: table_id - type: string - required: false - label: - en_US: table_id - zh_Hans: table_id - human_description: - en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. 
- zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 - llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 - form: llm - - - name: table_name - type: string - required: false - label: - en_US: table_name - zh_Hans: table_name - human_description: - en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. - zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 - llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 - form: llm - - - name: view_id - type: string - required: false - label: - en_US: view_id - zh_Hans: 视图唯一标识 - human_description: - en_US: | - Unique identifier for a view in a multidimensional table. It can be found in the URL's query parameter with the key 'view'. For example: https://svi136aogf123.feishu.cn/base/KWC8bYsYXahYqGsTtqectNn9n3e?table=tblE8a2fmBIEflaE&view=vewlkAVpRx. - zh_Hans: 多维表格中视图的唯一标识,可在多维表格的 URL 地址栏中找到,query 参数中 key 为 view 的部分。例如:https://svi136aogf123.feishu.cn/base/KWC8bYsYXahYqGsTtqectNn9n3e?table=tblE8a2fmBIEflaE&view=vewlkAVpRx。 - llm_description: 多维表格中视图的唯一标识,可在多维表格的 URL 地址栏中找到,query 参数中 key 为 view 的部分。例如:https://svi136aogf123.feishu.cn/base/KWC8bYsYXahYqGsTtqectNn9n3e?table=tblE8a2fmBIEflaE&view=vewlkAVpRx。 - form: llm - - - name: field_names - type: string - required: false - label: - en_US: field_names - zh_Hans: 字段名称 - human_description: - en_US: | - Field names to specify which fields to include in the returned records. Example value: ["Field1", "Field2"]. - zh_Hans: 字段名称,用于指定本次查询返回记录中包含的字段。示例值:["字段1","字段2"]。 - llm_description: 字段名称,用于指定本次查询返回记录中包含的字段。示例值:["字段1","字段2"]。 - form: llm - - - name: sort - type: string - required: false - label: - en_US: sort - zh_Hans: 排序条件 - human_description: - en_US: | - Sorting conditions, for example: [{"field_name":"Multiline Text","desc":true}]. - zh_Hans: 排序条件,例如:[{"field_name":"多行文本","desc":true}]。 - llm_description: 排序条件,例如:[{"field_name":"多行文本","desc":true}]。 - form: llm - - - name: filter - type: string - required: false - label: - en_US: filter - zh_Hans: 筛选条件 - human_description: - en_US: Object containing filter information. For details on how to fill in the filter, refer to the record filter parameter guide (https://open.larkoffice.com/document/uAjLw4CM/ukTMukTMukTM/reference/bitable-v1/app-table-record/record-filter-guide). - zh_Hans: 包含条件筛选信息的对象。了解如何填写 filter,参考记录筛选参数填写指南(https://open.larkoffice.com/document/uAjLw4CM/ukTMukTMukTM/reference/bitable-v1/app-table-record/record-filter-guide)。 - llm_description: 包含条件筛选信息的对象。了解如何填写 filter,参考记录筛选参数填写指南(https://open.larkoffice.com/document/uAjLw4CM/ukTMukTMukTM/reference/bitable-v1/app-table-record/record-filter-guide)。 - form: llm - - - name: automatic_fields - type: boolean - required: false - label: - en_US: automatic_fields - zh_Hans: automatic_fields - human_description: - en_US: Whether to return automatically calculated fields. Default is false, meaning they are not returned. - zh_Hans: 是否返回自动计算的字段。默认为 false,表示不返回。 - llm_description: 是否返回自动计算的字段。默认为 false,表示不返回。 - form: form - - - name: user_id_type - type: select - required: false - options: - - value: open_id - label: - en_US: open_id - zh_Hans: open_id - - value: union_id - label: - en_US: union_id - zh_Hans: union_id - - value: user_id - label: - en_US: user_id - zh_Hans: user_id - default: "open_id" - label: - en_US: user_id_type - zh_Hans: 用户 ID 类型 - human_description: - en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
- zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 - llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 - form: form - - - name: page_size - type: number - required: false - default: 20 - label: - en_US: page_size - zh_Hans: 分页大小 - human_description: - en_US: | - Page size, default value: 20, maximum value: 500. - zh_Hans: 分页大小,默认值:20,最大值:500。 - llm_description: 分页大小,默认值:20,最大值:500。 - form: llm - - - name: page_token - type: string - required: false - label: - en_US: page_token - zh_Hans: 分页标记 - human_description: - en_US: | - Page token, leave empty for the first request to start from the beginning; a new page_token will be returned if there are more items in the paginated query results, which can be used for the next traversal. Example value: "tblsRc9GRRXKqhvW". - zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 - llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_records.py b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py deleted file mode 100644 index a7b0363875..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/update_records.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool -from core.tools.utils.feishu_api_utils import FeishuRequest - - -class UpdateRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get("app_id") - app_secret = self.runtime.credentials.get("app_secret") - client = FeishuRequest(app_id, app_secret) - - app_token = tool_parameters.get("app_token") - table_id = tool_parameters.get("table_id") - table_name = tool_parameters.get("table_name") - records = tool_parameters.get("records") - user_id_type = tool_parameters.get("user_id_type", "open_id") - - res = client.update_records(app_token, table_id, table_name, records, user_id_type) - return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/update_records.yaml deleted file mode 100644 index 68117e7136..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/update_records.yaml +++ /dev/null @@ -1,91 +0,0 @@ -identity: - name: update_records - author: Doug Lea - label: - en_US: Update Records - zh_Hans: 更新多条记录 -description: - human: - en_US: Update Multiple Records in Multidimensional Table - zh_Hans: 更新多维表格数据表中的多条记录 - llm: A tool for updating multiple records in a multidimensional table. (更新多维表格数据表中的多条记录) -parameters: - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: app_token - human_description: - en_US: Unique identifier for the multidimensional table, supports inputting document URL. - zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 - llm_description: 多维表格的唯一标识符,支持输入文档 URL。 - form: llm - - - name: table_id - type: string - required: false - label: - en_US: table_id - zh_Hans: table_id - human_description: - en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. 
- zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 - llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 - form: llm - - - name: table_name - type: string - required: false - label: - en_US: table_name - zh_Hans: table_name - human_description: - en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. - zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 - llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 - form: llm - - - name: records - type: string - required: true - label: - en_US: records - zh_Hans: 记录列表 - human_description: - en_US: | - List of records to be updated in this request. Example value: [{"fields":{"multi-line-text":"text content","single_select":"option 1","date":1674206443000},"record_id":"recupK4f4RM5RX"}]. - For supported field types, refer to the integration guide (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification). For data structures of different field types, refer to the data structure overview (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure). - zh_Hans: | - 本次请求将要更新的记录列表,示例值:[{"fields":{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000},"record_id":"recupK4f4RM5RX"}]。 - 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 - llm_description: | - 本次请求将要更新的记录列表,示例值:[{"fields":{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000},"record_id":"recupK4f4RM5RX"}]。 - 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 - form: llm - - - name: user_id_type - type: select - required: false - options: - - value: open_id - label: - en_US: open_id - zh_Hans: open_id - - value: union_id - label: - en_US: union_id - zh_Hans: union_id - - value: user_id - label: - en_US: user_id - zh_Hans: user_id - default: "open_id" - label: - en_US: user_id_type - zh_Hans: 用户 ID 类型 - human_description: - en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
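> Reviewer note on the feishu_base deletions above: every removed tool repeats the same shape, namely read `app_id`/`app_secret` from runtime credentials, build a `FeishuRequest`, forward the tool parameters to one client method, and wrap the result with `create_json_message`. A hedged sketch of that shared pattern follows; `make_feishu_tool` is an illustrative factory, not a real Dify helper (the deleted files spell each class out by hand, and defaults such as `page_size=20` are omitted here).

```python
# Illustrative factory capturing the boilerplate shared by the deleted
# feishu_base tools. FeishuRequest and BuiltinTool are the Dify internals
# shown in the diff; make_feishu_tool itself is hypothetical.
from typing import Any

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.utils.feishu_api_utils import FeishuRequest


def make_feishu_tool(method_name: str, param_names: list[str]) -> type:
    class _FeishuTool(BuiltinTool):
        def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
            client = FeishuRequest(
                self.runtime.credentials.get("app_id"),
                self.runtime.credentials.get("app_secret"),
            )
            args = [tool_parameters.get(name) for name in param_names]
            res = getattr(client, method_name)(*args)
            return self.create_json_message(res)

    return _FeishuTool


# The deleted ListTablesTool, for example, is equivalent to:
ListTablesTool = make_feishu_tool("list_tables", ["app_token", "page_token", "page_size"])
```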
- zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 - llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 - form: form diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py deleted file mode 100644 index db43790c06..0000000000 --- a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import Any, Union - -import requests - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - -SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/image/generations" - -SD_MODELS = { - "sd_3": "stabilityai/stable-diffusion-3-medium", - "sd_xl": "stabilityai/stable-diffusion-xl-base-1.0", - "sd_3.5_large": "stabilityai/stable-diffusion-3-5-large", -} - - -class StableDiffusionTool(BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - headers = { - "accept": "application/json", - "content-type": "application/json", - "authorization": f"Bearer {self.runtime.credentials['siliconFlow_api_key']}", - } - - model = tool_parameters.get("model", "sd_3") - sd_model = SD_MODELS.get(model) - - payload = { - "model": sd_model, - "prompt": tool_parameters.get("prompt"), - "negative_prompt": tool_parameters.get("negative_prompt", ""), - "image_size": tool_parameters.get("image_size", "1024x1024"), - "batch_size": tool_parameters.get("batch_size", 1), - "seed": tool_parameters.get("seed"), - "guidance_scale": tool_parameters.get("guidance_scale", 7.5), - "num_inference_steps": tool_parameters.get("num_inference_steps", 20), - } - - response = requests.post(SILICONFLOW_API_URL, json=payload, headers=headers) - if response.status_code != 200: - return self.create_text_message(f"Got Error Response:{response.text}") - - res = response.json() - result = [self.create_json_message(res)] - for image in res.get("images", []): - result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) - return result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml deleted file mode 100644 index b330c92e16..0000000000 --- a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml +++ /dev/null @@ -1,124 +0,0 @@ -identity: - name: stable_diffusion - author: hjlarry - label: - en_US: Stable Diffusion - icon: icon.svg -description: - human: - en_US: Generate image via SiliconFlow's stable diffusion model. - llm: This tool is used to generate image from prompt via SiliconFlow's stable diffusion model. -parameters: - - name: prompt - type: string - required: true - label: - en_US: prompt - zh_Hans: 提示词 - human_description: - en_US: The text prompt used to generate the image. - zh_Hans: 用于生成图片的文字提示词 - llm_description: this prompt text will be used to generate image. - form: llm - - name: negative_prompt - type: string - label: - en_US: negative prompt - zh_Hans: 负面提示词 - human_description: - en_US: Describe what you don't want included in the image. - zh_Hans: 描述您不希望包含在图片中的内容。 - llm_description: Describe what you don't want included in the image. 
- form: llm - - name: model - type: select - required: true - options: - - value: sd_3 - label: - en_US: Stable Diffusion 3 - - value: sd_xl - label: - en_US: Stable Diffusion XL - - value: sd_3.5_large - label: - en_US: Stable Diffusion 3.5 Large - default: sd_3 - label: - en_US: Choose Image Model - zh_Hans: 选择生成图片的模型 - form: form - - name: image_size - type: select - required: true - options: - - value: 1024x1024 - label: - en_US: 1024x1024 - - value: 1024x2048 - label: - en_US: 1024x2048 - - value: 1152x2048 - label: - en_US: 1152x2048 - - value: 1536x1024 - label: - en_US: 1536x1024 - - value: 1536x2048 - label: - en_US: 1536x2048 - - value: 2048x1152 - label: - en_US: 2048x1152 - default: 1024x1024 - label: - en_US: Choose Image Size - zh_Hans: 选择生成图片的大小 - form: form - - name: batch_size - type: number - required: true - default: 1 - min: 1 - max: 4 - label: - en_US: Number Images - zh_Hans: 生成图片的数量 - form: form - - name: guidance_scale - type: number - required: true - default: 7.5 - min: 0 - max: 100 - label: - en_US: Guidance Scale - zh_Hans: 与提示词紧密性 - human_description: - en_US: Classifier Free Guidance. How close you want the model to stick to your prompt when looking for a related image to show you. - zh_Hans: 无分类器引导。您希望模型在寻找相关图片向您展示时,与您的提示保持多紧密的关联度。 - form: form - - name: num_inference_steps - type: number - required: true - default: 20 - min: 1 - max: 100 - label: - en_US: Num Inference Steps - zh_Hans: 生成图片的步数 - human_description: - en_US: The number of inference steps to perform. More steps produce higher quality but take longer. - zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。 - form: form - - name: seed - type: number - min: 0 - max: 9999999999 - label: - en_US: Seed - zh_Hans: 种子 - human_description: - en_US: The same seed and prompt can produce similar images. 
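> For reference, the HTTP call the removed SiliconFlow tool issues can be reproduced outside Dify. The endpoint, payload keys, model id, and the `images[].url` response shape are taken from the deleted code; the API key is a placeholder, and the explicit timeout is an addition of this sketch.

```python
# Plain-requests sketch of the removed SiliconFlow stable diffusion call.
import requests

SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/image/generations"

payload = {
    "model": "stabilityai/stable-diffusion-3-medium",  # the deleted sd_3 mapping
    "prompt": "a watercolor fox",
    "negative_prompt": "",
    "image_size": "1024x1024",
    "batch_size": 1,
    "guidance_scale": 7.5,
    "num_inference_steps": 20,
}
headers = {
    "accept": "application/json",
    "content-type": "application/json",
    "authorization": "Bearer YOUR_API_KEY",  # placeholder credential
}

response = requests.post(SILICONFLOW_API_URL, json=payload, headers=headers, timeout=60)
response.raise_for_status()  # the deleted tool returned response.text on non-200
for image in response.json().get("images", []):
    print(image.get("url"))
```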
- zh_Hans: 相同的种子和提示可以产生相似的图像。 - form: form diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py deleted file mode 100644 index c722cd36c8..0000000000 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py +++ /dev/null @@ -1,82 +0,0 @@ -from typing import Any, Union - -from httpx import post - -from core.file.enums import FileType -from core.file.file_manager import download -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from core.tools.errors import ToolParameterValidationError -from core.tools.tool.builtin_tool import BuiltinTool - - -class VectorizerTool(BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - """ - invoke tools - """ - api_key_name = self.runtime.credentials.get("api_key_name") - api_key_value = self.runtime.credentials.get("api_key_value") - mode = tool_parameters.get("mode", "test") - - # image file for workflow mode - image = tool_parameters.get("image") - if image and image.type != FileType.IMAGE: - raise ToolParameterValidationError("Not a valid image") - # image_id for agent mode - image_id = tool_parameters.get("image_id", "") - - if image_id: - image_binary = self.get_variable_file(self.VariableKey.IMAGE) - if not image_binary: - return self.create_text_message("Image not found, please request user to generate image firstly.") - elif image: - image_binary = download(image) - else: - raise ToolParameterValidationError("Please provide either image or image_id") - - response = post( - "https://vectorizer.ai/api/v1/vectorize", - data={"mode": mode}, - files={"image": image_binary}, - auth=(api_key_name, api_key_value), - timeout=30, - ) - - if response.status_code != 200: - raise Exception(response.text) - - return [ - self.create_text_message("the vectorized svg is saved as an image."), - self.create_blob_message(blob=response.content, meta={"mime_type": "image/svg+xml"}), - ] - - def get_runtime_parameters(self) -> list[ToolParameter]: - """ - override the runtime parameters - """ - return [ - ToolParameter.get_simple_instance( - name="image_id", - llm_description=f"the image_id that you want to vectorize, \ - and the image_id should be specified in \ - {[i.name for i in self.list_default_image_variables()]}", - type=ToolParameter.ToolParameterType.SELECT, - required=False, - options=[i.name for i in self.list_default_image_variables()], - ), - ToolParameter( - name="image", - label=I18nObject(en_US="image", zh_Hans="image"), - human_description=I18nObject( - en_US="The image to be converted.", - zh_Hans="要转换的图片。", - ), - type=ToolParameter.ToolParameterType.FILE, - form=ToolParameter.ToolParameterForm.LLM, - llm_description="you should not input this parameter. just input the image_id.", - required=False, - ), - ] diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.yaml b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.yaml deleted file mode 100644 index 0afd1c201f..0000000000 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.yaml +++ /dev/null @@ -1,41 +0,0 @@ -identity: - name: vectorizer - author: Dify - label: - en_US: Vectorizer.AI - zh_Hans: Vectorizer.AI -description: - human: - en_US: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI. 
- zh_Hans: 一个将 PNG 和 JPG 图像快速轻松地转换为 SVG 矢量图的工具。 - llm: A tool for converting images to SVG vectors. you should input the image id as the input of this tool. the image id can be got from parameters. -parameters: - - name: image - type: file - label: - en_US: image - human_description: - en_US: The image to be converted. - zh_Hans: 要转换的图片。 - llm_description: you should not input this parameter. just input the image_id. - form: llm - - name: mode - type: select - required: true - options: - - value: production - label: - en_US: production - zh_Hans: 生产模式 - - value: test - label: - en_US: test - zh_Hans: 测试模式 - default: test - label: - en_US: Mode - zh_Hans: 模式 - human_description: - en_US: It is free to integrate with and test out the API in test mode, no subscription required. - zh_Hans: 在测试模式下,可以免费测试API。 - form: form diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py deleted file mode 100644 index 8140348723..0000000000 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Any - -from core.file import File -from core.file.enums import FileTransferMethod, FileType -from core.tools.errors import ToolProviderCredentialValidationError -from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool -from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController - - -class VectorizerProvider(BuiltinToolProviderController): - def _validate_credentials(self, credentials: dict[str, Any]) -> None: - test_img = File( - tenant_id="__test_123", - remote_url="https://cloud.dify.ai/logo/logo-site.png", - type=FileType.IMAGE, - transfer_method=FileTransferMethod.REMOTE_URL, - ) - try: - VectorizerTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id="", - tool_parameters={"mode": "test", "image": test_img}, - ) - except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml b/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml deleted file mode 100644 index 94dae20876..0000000000 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml +++ /dev/null @@ -1,39 +0,0 @@ -identity: - author: Dify - name: vectorizer - label: - en_US: Vectorizer.AI - zh_Hans: Vectorizer.AI - description: - en_US: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI. - zh_Hans: 一个将 PNG 和 JPG 图像快速轻松地转换为 SVG 矢量图的工具。 - icon: icon.png - tags: - - productivity - - image -credentials_for_provider: - api_key_name: - type: secret-input - required: true - label: - en_US: Vectorizer.AI API Key name - zh_Hans: Vectorizer.AI API Key name - placeholder: - en_US: Please input your Vectorizer.AI ApiKey name - zh_Hans: 请输入你的 Vectorizer.AI ApiKey name - help: - en_US: Get your Vectorizer.AI API Key from Vectorizer.AI. - zh_Hans: 从 Vectorizer.AI 获取您的 Vectorizer.AI API Key。 - url: https://vectorizer.ai/api - api_key_value: - type: secret-input - required: true - label: - en_US: Vectorizer.AI API Key - zh_Hans: Vectorizer.AI API Key - placeholder: - en_US: Please input your Vectorizer.AI ApiKey - zh_Hans: 请输入你的 Vectorizer.AI ApiKey - help: - en_US: Get your Vectorizer.AI API Key from Vectorizer.AI. 
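> Similarly, the core of the removed VectorizerTool is a single multipart POST. The endpoint, form fields, basic-auth pair, and timeout come from the deleted code; the `vectorize` wrapper and the key placeholders are illustrative.

```python
# Minimal sketch of the HTTP call the removed VectorizerTool makes.
from httpx import post


def vectorize(image_bytes: bytes, api_key_name: str, api_key_value: str) -> bytes:
    response = post(
        "https://vectorizer.ai/api/v1/vectorize",
        data={"mode": "test"},  # test mode is free per the deleted YAML
        files={"image": image_bytes},
        auth=(api_key_name, api_key_value),
        timeout=30,
    )
    if response.status_code != 200:
        raise RuntimeError(response.text)
    return response.content  # SVG bytes (image/svg+xml)


# Usage: svg = vectorize(open("logo.png", "rb").read(), "key-name", "key-value")
```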
- zh_Hans: 从 Vectorizer.AI 获取您的 Vectorizer.AI API Key。 diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 9e09b6d29a..c2f51ad1e5 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -5,6 +5,7 @@ import json import docx import pandas as pd import pypdfium2 +import yaml from unstructured.partition.email import partition_email from unstructured.partition.epub import partition_epub from unstructured.partition.msg import partition_msg @@ -101,6 +102,8 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: return _extract_text_from_msg(file_content) case "application/json": return _extract_text_from_json(file_content) + case "application/x-yaml" | "text/yaml": + return _extract_text_from_yaml(file_content) case _: raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") @@ -112,6 +115,8 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) return _extract_text_from_plain_text(file_content) case ".json": return _extract_text_from_json(file_content) + case ".yaml" | ".yml": + return _extract_text_from_yaml(file_content) case ".pdf": return _extract_text_from_pdf(file_content) case ".doc" | ".docx": @@ -149,6 +154,15 @@ def _extract_text_from_json(file_content: bytes) -> str: raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e +def _extract_text_from_yaml(file_content: bytes) -> str: + """Extract the content from yaml file""" + try: + yaml_data = yaml.safe_load_all(file_content.decode("utf-8")) + return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) + except (UnicodeDecodeError, yaml.YAMLError) as e: + raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e + + def _extract_text_from_pdf(file_content: bytes) -> str: try: pdf_file = io.BytesIO(file_content) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 472587cb03..b4728e6abf 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -349,13 +349,11 @@ class LLMNode(BaseNode[LLMNodeData]): variable = self.graph_runtime_state.variable_pool.get(selector) if variable is None: return [] - if isinstance(variable, FileSegment): + elif isinstance(variable, FileSegment): return [variable.value] - if isinstance(variable, ArrayFileSegment): + elif isinstance(variable, ArrayFileSegment): return variable.value - # FIXME: Temporary fix for empty array, - # all variables added to variable pool should be a Segment instance. 
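> The YAML branch added to the document extractor (the `_extract_text_from_yaml` hunk above) round-trips multi-document files through PyYAML: `safe_load_all` parses every document in the stream, and `dump_all` re-serializes them with unicode preserved. A quick standalone check of the same two calls:

```python
# Standalone check of the extraction added above. Requires PyYAML.
import yaml


def extract_text_from_yaml(file_content: bytes) -> str:
    yaml_data = yaml.safe_load_all(file_content.decode("utf-8"))
    return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)


# Two YAML documents separated by "---"; the unicode value survives the round trip.
sample = "name: 飞书\n---\nitems: [1, 2]\n".encode("utf-8")
print(extract_text_from_yaml(sample))
```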
- if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0: + elif isinstance(variable, NoneSegment | ArrayAnySegment): return [] raise ValueError(f"Invalid variable type: {type(variable)}") diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 0fa832f420..56b1d6bd28 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -1,10 +1,8 @@ import logging import os import sys -from datetime import datetime from logging.handlers import RotatingFileHandler -import pytz from flask import Flask from configs import dify_config @@ -32,10 +30,16 @@ def init_app(app: Flask): handlers=log_handlers, force=True, ) - log_tz = dify_config.LOG_TZ if log_tz: + from datetime import datetime + + import pytz + + timezone = pytz.timezone(log_tz) + + def time_converter(seconds): + return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() + for handler in logging.root.handlers: - handler.formatter.converter = lambda seconds: ( - datetime.fromtimestamp(seconds, tz=pytz.UTC).astimezone(log_tz).timetuple() - ) + handler.formatter.converter = time_converter diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index ead7b9a8b3..1066dc8862 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -160,7 +160,7 @@ def _build_from_local_file( tenant_id=tenant_id, type=file_type, transfer_method=transfer_method, - remote_url=None, + remote_url=row.source_url, related_id=mapping.get("upload_file_id"), _extra_config=config, size=row.size, diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index a758f9981f..d0c8c7e84f 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -10,6 +10,7 @@ from core.variables import ( ArrayNumberVariable, ArrayObjectSegment, ArrayObjectVariable, + ArraySegment, ArrayStringSegment, ArrayStringVariable, FileSegment, @@ -79,7 +80,7 @@ def build_segment(value: Any, /) -> Segment: if isinstance(value, list): items = [build_segment(item) for item in value] types = {item.value_type for item in items} - if len(types) != 1: + if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items): return ArrayAnySegment(value=value) match types.pop(): case SegmentType.STRING: diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index bf1c491a05..2eb19c2667 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -121,6 +121,7 @@ conversation_fields = { "from_account_name": fields.String, "read_at": TimestampField, "created_at": TimestampField, + "updated_at": TimestampField, "annotation": fields.Nested(annotation_fields, allow_null=True), "model_config": fields.Nested(simple_model_config_fields), "user_feedback_stats": fields.Nested(feedback_stat_fields), @@ -182,6 +183,7 @@ conversation_detail_fields = { "from_end_user_id": fields.String, "from_account_id": fields.String, "created_at": TimestampField, + "updated_at": TimestampField, "annotated": fields.Boolean, "introduction": fields.String, "model_config": fields.Nested(model_config_fields), @@ -197,6 +199,7 @@ simple_conversation_fields = { "status": fields.String, "introduction": fields.String, "created_at": TimestampField, + "updated_at": TimestampField, } conversation_infinite_scroll_pagination_fields = { diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 9ff1111b74..1cddc24b2c 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -24,3 +24,15 @@ 
remote_file_info_fields = { "file_type": fields.String(attribute="file_type"), "file_length": fields.Integer(attribute="file_length"), } + + +file_fields_with_signed_url = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "url": fields.String, + "mime_type": fields.String, + "created_by": fields.String, + "created_at": TimestampField, +} diff --git a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py index 6a7402b16a..153861a71a 100644 --- a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py +++ b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py @@ -28,16 +28,12 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), sa.PrimaryKeyConstraint('id', name='tracing_app_config_pkey') ) - with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op: - batch_op.create_index('tracing_app_config_app_id_idx', ['app_id'], unique=False) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ## - with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op: - batch_op.drop_index('tracing_app_config_app_id_idx') - op.drop_table('tracing_app_configs') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py new file mode 100644 index 0000000000..a749c8bddf --- /dev/null +++ b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py @@ -0,0 +1,31 @@ +"""Add upload_files.source_url + +Revision ID: d3f6769a94a3 +Revises: 43fa78bc3b7d +Create Date: 2024-11-01 04:34:23.816198 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'd3f6769a94a3' +down_revision = '43fa78bc3b7d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.add_column(sa.Column('source_url', sa.String(length=255), server_default='', nullable=False)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.drop_column('source_url') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py b/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py new file mode 100644 index 0000000000..81a7978f73 --- /dev/null +++ b/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py @@ -0,0 +1,52 @@ +"""rename conversation variables index name + +Revision ID: 93ad8c19c40b +Revises: d3f6769a94a3 +Create Date: 2024-11-01 04:49:53.100250 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '93ad8c19c40b' +down_revision = 'd3f6769a94a3' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + conn = op.get_bind() + if conn.dialect.name == 'postgresql': + # Rename indexes for PostgreSQL + op.execute('ALTER INDEX workflow__conversation_variables_app_id_idx RENAME TO workflow_conversation_variables_app_id_idx') + op.execute('ALTER INDEX workflow__conversation_variables_created_at_idx RENAME TO workflow_conversation_variables_created_at_idx') + else: + # For other databases, use the original drop and create method + with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op: + batch_op.drop_index('workflow__conversation_variables_app_id_idx') + batch_op.drop_index('workflow__conversation_variables_created_at_idx') + batch_op.create_index(batch_op.f('workflow_conversation_variables_app_id_idx'), ['app_id'], unique=False) + batch_op.create_index(batch_op.f('workflow_conversation_variables_created_at_idx'), ['created_at'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + conn = op.get_bind() + if conn.dialect.name == 'postgresql': + # Rename indexes back for PostgreSQL + op.execute('ALTER INDEX workflow_conversation_variables_app_id_idx RENAME TO workflow__conversation_variables_app_id_idx') + op.execute('ALTER INDEX workflow_conversation_variables_created_at_idx RENAME TO workflow__conversation_variables_created_at_idx') + else: + # For other databases, use the original drop and create method + with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('workflow_conversation_variables_created_at_idx')) + batch_op.drop_index(batch_op.f('workflow_conversation_variables_app_id_idx')) + batch_op.create_index('workflow__conversation_variables_created_at_idx', ['created_at'], unique=False) + batch_op.create_index('workflow__conversation_variables_app_id_idx', ['app_id'], unique=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py new file mode 100644 index 0000000000..222379a490 --- /dev/null +++ b/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py @@ -0,0 +1,41 @@ +"""update upload_files.source_url + +Revision ID: f4d7ce70a7ca +Revises: 93ad8c19c40b +Create Date: 2024-11-01 05:40:03.531751 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'f4d7ce70a7ca' +down_revision = '93ad8c19c40b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.alter_column('source_url', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + existing_nullable=False, + existing_server_default=sa.text("''::character varying")) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
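On PostgreSQL the `VARCHAR(255)` to `TEXT` change above is metadata-only, since the two types share a storage representation, so no rows are rewritten. A hand-written equivalent of the batch alter, as a sketch:

```python
from alembic import op
import sqlalchemy as sa

def upgrade():
    # Widen upload_files.source_url; metadata-only on PostgreSQL.
    op.alter_column(
        "upload_files",
        "source_url",
        existing_type=sa.VARCHAR(length=255),
        type_=sa.TEXT(),
        existing_nullable=False,
    )
```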
### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.alter_column('source_url', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + existing_nullable=False, + existing_server_default=sa.text("''::character varying")) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py new file mode 100644 index 0000000000..9a4ccf352d --- /dev/null +++ b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py @@ -0,0 +1,67 @@ +"""update type of custom_disclaimer to TEXT + +Revision ID: d07474999927 +Revises: f4d7ce70a7ca +Create Date: 2024-11-01 06:22:27.981398 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'd07474999927' +down_revision = 'f4d7ce70a7ca' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.execute("UPDATE recommended_apps SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") + op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") + op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") + + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + nullable=False) + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + nullable=False) + + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + nullable=True) + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + nullable=True) + + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + nullable=True) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py new file mode 100644 index 0000000000..117a7351cd --- /dev/null +++ b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py @@ -0,0 +1,73 @@ +"""update workflows graph, features and updated_at + +Revision ID: 09a8d1878d9b +Revises: d07474999927 +Create Date: 2024-11-01 06:23:59.579186 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
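The migrations added in this diff chain through `down_revision` into one linear history that `alembic upgrade head` replays in order; summarized as comments:

```python
# Revision chain for the migrations in this diff (each down_revision
# points at its parent):
#
#   43fa78bc3b7d  (prior head)
#     -> d3f6769a94a3   add upload_files.source_url
#     -> 93ad8c19c40b   rename conversation variables indexes
#     -> f4d7ce70a7ca   widen upload_files.source_url to TEXT
#     -> d07474999927   custom_disclaimer columns to TEXT
#     -> 09a8d1878d9b   workflows graph/features/updated_at constraints
```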
+revision = '09a8d1878d9b' +down_revision = 'd07474999927' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False) + + op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL") + op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL") + op.execute("UPDATE workflows SET features = '' WHERE features IS NULL") + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.alter_column('graph', + existing_type=sa.TEXT(), + nullable=False) + batch_op.alter_column('features', + existing_type=sa.TEXT(), + nullable=False) + batch_op.alter_column('updated_at', + existing_type=postgresql.TIMESTAMP(), + nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.alter_column('updated_at', + existing_type=postgresql.TIMESTAMP(), + nullable=True) + batch_op.alter_column('features', + existing_type=sa.TEXT(), + nullable=True) + batch_op.alter_column('graph', + existing_type=sa.TEXT(), + nullable=True) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True) + + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py index 09ef5e186c..99b7010612 100644 --- a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py +++ b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py @@ -22,17 +22,11 @@ def upgrade(): with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.add_column(sa.Column('tracing', sa.Text(), nullable=True)) - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.create_index('tracing_app_config_app_id_idx', ['app_id'], unique=False) - # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
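Migration 09a8d1878d9b above uses the standard two-step tightening pattern: backfill rows that would violate the constraint, then declare it. Distilled into a sketch (`sa.DateTime()` stands in for the generated `postgresql.TIMESTAMP()`):

```python
from alembic import op
import sqlalchemy as sa

def upgrade():
    # 1) Make every existing row satisfy the constraint-to-be...
    op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL")
    # 2) ...and only then declare the column NOT NULL.
    with op.batch_alter_table("workflows", schema=None) as batch_op:
        batch_op.alter_column("updated_at", existing_type=sa.DateTime(), nullable=False)
```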
### - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.drop_index('tracing_app_config_app_id_idx') - with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.drop_column('tracing') diff --git a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py index 469c04338a..f87819c367 100644 --- a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py +++ b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py @@ -30,30 +30,15 @@ def upgrade(): sa.Column('is_active', sa.Boolean(), server_default=sa.text('true'), nullable=False), sa.PrimaryKeyConstraint('id', name='trace_app_config_pkey') ) + with op.batch_alter_table('trace_app_config', schema=None) as batch_op: batch_op.create_index('trace_app_config_app_id_idx', ['app_id'], unique=False) - with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op: - batch_op.drop_index('tracing_app_config_app_id_idx') # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.create_table('tracing_app_configs', - sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False), - sa.Column('app_id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('tracing_provider', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('tracing_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), - sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('now()'), autoincrement=False, nullable=False), - sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('now()'), autoincrement=False, nullable=False), - sa.PrimaryKeyConstraint('id', name='trace_app_config_pkey') - ) - with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op: - batch_op.create_index('trace_app_config_app_id_idx', ['app_id'], unique=False) - - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.drop_index('trace_app_config_app_id_idx') - op.drop_table('trace_app_config') + # ### end Alembic commands ### diff --git a/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py b/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py index 271b2490de..6f76a361d9 100644 --- a/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py +++ b/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py @@ -20,12 +20,10 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.drop_table('tracing_app_configs') - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.drop_index('tracing_app_config_app_id_idx') - # idx_dataset_permissions_tenant_id with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: batch_op.create_index('idx_dataset_permissions_tenant_id', ['tenant_id']) + # ### end Alembic commands ### @@ -46,9 +44,7 @@ def downgrade(): sa.PrimaryKeyConstraint('id', name='tracing_app_config_pkey') ) - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.create_index('tracing_app_config_app_id_idx', ['app_id']) - with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: batch_op.drop_index('idx_dataset_permissions_tenant_id') + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 27b1e5e61f..8a619d3f30 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from typing import Any, Literal +import sqlalchemy as sa from flask import request from flask_login import UserMixin from pydantic import BaseModel, Field @@ -406,7 +407,7 @@ class AppModelConfig(Base): "file_upload": self.file_upload_dict, } - def from_model_config_dict(self, model_config: dict): + def from_model_config_dict(self, model_config: Mapping[str, Any]): self.opening_statement = model_config.get("opening_statement") self.suggested_questions = ( json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None @@ -493,7 +494,7 @@ class RecommendedApp(Base): description = db.Column(db.JSON, nullable=False) copyright = db.Column(db.String(255), nullable=False) privacy_policy = db.Column(db.String(255), nullable=False) - custom_disclaimer = db.Column(db.String(255), nullable=True) + custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") category = db.Column(db.String(255), nullable=False) position = db.Column(db.Integer, nullable=False, default=0) is_listed = db.Column(db.Boolean, nullable=False, default=True) @@ -1319,7 +1320,7 @@ class Site(Base): privacy_policy = db.Column(db.String(255)) show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - custom_disclaimer = db.Column(db.String(255), nullable=True) + custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") customize_domain = db.Column(db.String(255)) customize_token_strategy = db.Column(db.String(255), nullable=False) prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @@ -1397,6 +1398,7 @@ class UploadFile(Base): used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True) used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True) hash: Mapped[str | None] = db.Column(db.String(255), nullable=True) + source_url: Mapped[str] = mapped_column(sa.TEXT, default="") def __init__( self, @@ -1415,7 +1417,8 @@ class UploadFile(Base): used_by: str | None = None, used_at: datetime | None = None, hash: str | None = None, - ) -> None: + source_url: str = "", + ): self.tenant_id = tenant_id self.storage_type = storage_type self.key = key @@ -1430,6 +1433,7 @@ class UploadFile(Base): self.used_by = used_by self.used_at = used_at self.hash = hash + self.source_url = source_url class ApiRequest(Base): diff --git a/api/models/tools.py b/api/models/tools.py index 869dd0201f..248e28e0b9 100644 --- a/api/models/tools.py +++ 
b/api/models/tools.py @@ -2,6 +2,7 @@ import json from datetime import datetime from typing import Optional +import sqlalchemy as sa from deprecated import deprecated from sqlalchemy import ForeignKey from sqlalchemy.orm import Mapped, mapped_column @@ -82,7 +83,7 @@ class ApiToolProvider(Base): # privacy policy privacy_policy = db.Column(db.String(255), nullable=True) # custom_disclaimer - custom_disclaimer = db.Column(db.String(255), nullable=True) + custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/workflow.py b/api/models/workflow.py index da3152ec75..bc4434ae5a 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,11 +2,12 @@ import json from collections.abc import Mapping, Sequence from datetime import datetime from enum import Enum -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Optional, Union if TYPE_CHECKING: from models.model import AppMode +import sqlalchemy as sa from sqlalchemy import Index, PrimaryKeyConstraint, func from sqlalchemy.orm import Mapped, mapped_column @@ -103,14 +104,14 @@ class Workflow(Base): app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(db.String(255), nullable=False) version: Mapped[str] = mapped_column(db.String(255), nullable=False) - graph: Mapped[str] = mapped_column(db.Text) - _features: Mapped[str] = mapped_column("features") + graph: Mapped[str] = mapped_column(sa.Text) + _features: Mapped[str] = mapped_column("features", sa.TEXT) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) - updated_by: Mapped[str] = mapped_column(StringUUID) - updated_at: Mapped[datetime] = mapped_column(db.DateTime) + updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False) _environment_variables: Mapped[str] = mapped_column( "environment_variables", db.Text, nullable=False, server_default="{}" ) diff --git a/api/poetry.lock b/api/poetry.lock index 233572ebfb..d7af124794 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. 
[[package]] name = "aiohappyeyeballs" @@ -932,10 +932,6 @@ files = [ {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec"}, {file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"}, {file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"}, {file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"}, @@ -948,14 +944,8 @@ files = [ {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b"}, {file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"}, {file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"}, - {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28"}, - {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"}, {file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"}, @@ 
-966,24 +956,8 @@ files = [ {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839"}, {file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"}, {file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"}, - {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5"}, - {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7"}, - {file = "Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0"}, - {file = "Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b"}, {file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"}, @@ -993,10 +967,6 @@ files = [ {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:aea440a510e14e818e67bfc4027880e2fb500c2ccb20ab21c7a7c8b5b4703d75"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:6974f52a02321b36847cd19d1b8e381bf39939c21efd6ee2fc13a28b0d99348c"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:a7e53012d2853a07a4a79c00643832161a910674a893d296c9f1259859a289d2"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:d7702622a8b40c49bffb46e1e3ba2e81268d5c04a34f460978c6b5517a34dd52"}, {file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"}, {file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"}, {file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"}, @@ -1008,10 +978,6 @@ files = [ {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cb1dac1770878ade83f2ccdf7d25e494f05c9165f5246b46a621cc849341dc01"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:3ee8a80d67a4334482d9712b8e83ca6b1d9bc7e351931252ebef5d8f7335a547"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5e55da2c8724191e5b557f8e18943b1b4839b8efc3ef60d65985bcf6f587dd38"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:d342778ef319e1026af243ed0a07c97acf3bad33b9f29e7ae6a1f68fd083e90c"}, {file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"}, {file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"}, {file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"}, @@ -1024,10 +990,6 @@ files = [ {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = 
"sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a"}, {file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"}, {file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"}, {file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"}, @@ -1040,10 +1002,6 @@ files = [ {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb"}, {file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"}, {file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"}, {file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"}, diff --git a/api/services/app_dsl_service/__init__.py b/api/services/app_dsl_service/__init__.py new file mode 100644 index 0000000000..9fc988ffb3 --- /dev/null +++ b/api/services/app_dsl_service/__init__.py @@ -0,0 +1,3 @@ +from .service import AppDslService + +__all__ = ["AppDslService"] diff --git a/api/services/app_dsl_service/exc.py b/api/services/app_dsl_service/exc.py new file mode 100644 index 0000000000..6da4b1938f --- /dev/null +++ b/api/services/app_dsl_service/exc.py @@ -0,0 +1,34 @@ +class DSLVersionNotSupportedError(ValueError): + """Raised when the imported DSL version is not supported by the current Dify version.""" + + +class InvalidYAMLFormatError(ValueError): + """Raised when the provided YAML format is invalid.""" + + +class MissingAppDataError(ValueError): + """Raised when the app data is missing in the provided DSL.""" + + +class InvalidAppModeError(ValueError): + """Raised when the app mode is invalid.""" + + +class MissingWorkflowDataError(ValueError): + """Raised when the workflow data is missing in the provided DSL.""" + + +class MissingModelConfigError(ValueError): + """Raised when the model config data is missing in the provided DSL.""" + + +class FileSizeLimitExceededError(ValueError): + 
"""Raised when the file size exceeds the allowed limit.""" + + +class EmptyContentError(ValueError): + """Raised when the content fetched from the URL is empty.""" + + +class ContentDecodingError(ValueError): + """Raised when there is an error decoding the content.""" diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service/service.py similarity index 75% rename from api/services/app_dsl_service.py rename to api/services/app_dsl_service/service.py index 750d0a8cd2..32b95ae3aa 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service/service.py @@ -1,8 +1,11 @@ import logging +from collections.abc import Mapping +from typing import Any -import httpx -import yaml # type: ignore +import yaml +from packaging import version +from core.helper import ssrf_proxy from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_database import db from factories import variable_factory @@ -11,6 +14,17 @@ from models.model import App, AppMode, AppModelConfig from models.workflow import Workflow from services.workflow_service import WorkflowService +from .exc import ( + ContentDecodingError, + EmptyContentError, + FileSizeLimitExceededError, + InvalidAppModeError, + InvalidYAMLFormatError, + MissingAppDataError, + MissingModelConfigError, + MissingWorkflowDataError, +) + logger = logging.getLogger(__name__) current_dsl_version = "0.1.2" @@ -30,32 +44,21 @@ class AppDslService: :param args: request args :param account: Account instance """ - try: - max_size = 10 * 1024 * 1024 # 10MB - timeout = httpx.Timeout(10.0) - with httpx.stream("GET", url.strip(), follow_redirects=True, timeout=timeout) as response: - response.raise_for_status() - total_size = 0 - content = b"" - for chunk in response.iter_bytes(): - total_size += len(chunk) - if total_size > max_size: - raise ValueError("File size exceeds the limit of 10MB") - content += chunk - except httpx.HTTPStatusError as http_err: - raise ValueError(f"HTTP error occurred: {http_err}") - except httpx.RequestError as req_err: - raise ValueError(f"Request error occurred: {req_err}") - except Exception as e: - raise ValueError(f"Failed to fetch DSL from URL: {e}") + max_size = 10 * 1024 * 1024 # 10MB + response = ssrf_proxy.get(url.strip(), follow_redirects=True, timeout=(10, 10)) + response.raise_for_status() + content = response.content + + if len(content) > max_size: + raise FileSizeLimitExceededError("File size exceeds the limit of 10MB") if not content: - raise ValueError("Empty content from url") + raise EmptyContentError("Empty content from url") try: data = content.decode("utf-8") except UnicodeDecodeError as e: - raise ValueError(f"Error decoding content: {e}") + raise ContentDecodingError(f"Error decoding content: {e}") return cls.import_and_create_new_app(tenant_id, data, args, account) @@ -71,14 +74,14 @@ class AppDslService: try: import_data = yaml.safe_load(data) except yaml.YAMLError: - raise ValueError("Invalid YAML format in data argument.") + raise InvalidYAMLFormatError("Invalid YAML format in data argument.") # check or repair dsl version - import_data = cls._check_or_fix_dsl(import_data) + import_data = _check_or_fix_dsl(import_data) app_data = import_data.get("app") if not app_data: - raise ValueError("Missing app in data argument") + raise MissingAppDataError("Missing app in data argument") # get app basic info name = args.get("name") or app_data.get("name") @@ -90,11 +93,18 @@ class AppDslService: # import dsl and create app app_mode = 
AppMode.value_of(app_data.get("mode")) + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + workflow_data = import_data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) + app = cls._import_and_create_new_workflow_based_app( tenant_id=tenant_id, app_mode=app_mode, - workflow_data=import_data.get("workflow"), + workflow_data=workflow_data, account=account, name=name, description=description, @@ -104,10 +114,16 @@ class AppDslService: use_icon_as_answer_icon=use_icon_as_answer_icon, ) elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: + model_config = import_data.get("model_config") + if not model_config or not isinstance(model_config, dict): + raise MissingModelConfigError( + "Missing model_config in data argument when app mode is chat, agent-chat or completion" + ) + app = cls._import_and_create_new_model_config_based_app( tenant_id=tenant_id, app_mode=app_mode, - model_config_data=import_data.get("model_config"), + model_config_data=model_config, account=account, name=name, description=description, @@ -117,7 +133,7 @@ class AppDslService: use_icon_as_answer_icon=use_icon_as_answer_icon, ) else: - raise ValueError("Invalid app mode") + raise InvalidAppModeError("Invalid app mode") return app @@ -132,26 +148,32 @@ class AppDslService: try: import_data = yaml.safe_load(data) except yaml.YAMLError: - raise ValueError("Invalid YAML format in data argument.") + raise InvalidYAMLFormatError("Invalid YAML format in data argument.") # check or repair dsl version - import_data = cls._check_or_fix_dsl(import_data) + import_data = _check_or_fix_dsl(import_data) app_data = import_data.get("app") if not app_data: - raise ValueError("Missing app in data argument") + raise MissingAppDataError("Missing app in data argument") # import dsl and overwrite app app_mode = AppMode.value_of(app_data.get("mode")) if app_mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - raise ValueError("Only support import workflow in advanced-chat or workflow app.") + raise InvalidAppModeError("Only support import workflow in advanced-chat or workflow app.") if app_data.get("mode") != app_model.mode: raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") + workflow_data = import_data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) + return cls._import_and_overwrite_workflow_based_app( app_model=app_model, - workflow_data=import_data.get("workflow"), + workflow_data=workflow_data, account=account, ) @@ -186,35 +208,12 @@ class AppDslService: return yaml.dump(export_data, allow_unicode=True) - @classmethod - def _check_or_fix_dsl(cls, import_data: dict) -> dict: - """ - Check or fix dsl - - :param import_data: import data - """ - if not import_data.get("version"): - import_data["version"] = "0.1.0" - - if not import_data.get("kind") or import_data.get("kind") != "app": - import_data["kind"] = "app" - - if import_data.get("version") != current_dsl_version: - # Currently only one DSL version, so no difference checks or compatibility fixes will be performed. 
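The new `exc` module gives callers typed failure modes while staying backward compatible, since every class still subclasses `ValueError` (note the version mismatch currently only warns, so catching `DSLVersionNotSupportedError` is forward-looking). A hedged usage sketch with placeholder handler bodies:

```python
from services.app_dsl_service import AppDslService
from services.app_dsl_service.exc import (
    DSLVersionNotSupportedError,
    InvalidYAMLFormatError,
)

def import_dsl(tenant_id: str, data: str, args: dict, account) -> None:
    try:
        AppDslService.import_and_create_new_app(tenant_id, data, args, account)
    except DSLVersionNotSupportedError:
        ...  # prompt the operator to upgrade Dify first
    except InvalidYAMLFormatError:
        ...  # surface the YAML syntax problem
    except ValueError:
        ...  # any other DSL validation failure, handled as before
```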
- logger.warning( - f"DSL version {import_data.get('version')} is not compatible " - f"with current version {current_dsl_version}, related to " - f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}." - ) - - return import_data - @classmethod def _import_and_create_new_workflow_based_app( cls, tenant_id: str, app_mode: AppMode, - workflow_data: dict, + workflow_data: Mapping[str, Any], account: Account, name: str, description: str, @@ -238,7 +237,9 @@ class AppDslService: :param use_icon_as_answer_icon: use app icon as answer icon """ if not workflow_data: - raise ValueError("Missing workflow in data argument when app mode is advanced-chat or workflow") + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) app = cls._create_app( tenant_id=tenant_id, @@ -277,7 +278,7 @@ class AppDslService: @classmethod def _import_and_overwrite_workflow_based_app( - cls, app_model: App, workflow_data: dict, account: Account + cls, app_model: App, workflow_data: Mapping[str, Any], account: Account ) -> Workflow: """ Import app dsl and overwrite workflow based app @@ -287,7 +288,9 @@ class AppDslService: :param account: Account instance """ if not workflow_data: - raise ValueError("Missing workflow in data argument when app mode is advanced-chat or workflow") + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) # fetch draft workflow by app_model workflow_service = WorkflowService() @@ -323,7 +326,7 @@ class AppDslService: cls, tenant_id: str, app_mode: AppMode, - model_config_data: dict, + model_config_data: Mapping[str, Any], account: Account, name: str, description: str, @@ -345,7 +348,9 @@ class AppDslService: :param icon_background: app icon background """ if not model_config_data: - raise ValueError("Missing model_config in data argument when app mode is chat, agent-chat or completion") + raise MissingModelConfigError( + "Missing model_config in data argument when app mode is chat, agent-chat or completion" + ) app = cls._create_app( tenant_id=tenant_id, @@ -448,3 +453,36 @@ class AppDslService: raise ValueError("Missing app configuration, please check.") export_data["model_config"] = app_model_config.to_dict() + + +def _check_or_fix_dsl(import_data: dict[str, Any]) -> Mapping[str, Any]: + """ + Check or fix dsl + + :param import_data: import data + :raises DSLVersionNotSupportedError: if the imported DSL version is newer than the current version + """ + if not import_data.get("version"): + import_data["version"] = "0.1.0" + + if not import_data.get("kind") or import_data.get("kind") != "app": + import_data["kind"] = "app" + + imported_version = import_data.get("version") + if imported_version != current_dsl_version: + if imported_version and version.parse(imported_version) > version.parse(current_dsl_version): + errmsg = ( + f"The imported DSL version {imported_version} is newer than " + f"the current supported version {current_dsl_version}. " + f"Please upgrade your Dify instance to import this configuration." + ) + logger.warning(errmsg) + # raise DSLVersionNotSupportedError(errmsg) + else: + logger.warning( + f"DSL version {imported_version} is older than " + f"the current version {current_dsl_version}. " + f"This may cause compatibility issues." 
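The move to `packaging.version` matters because version strings do not sort lexically; a two-line sketch of the difference:

```python
from packaging import version

assert "0.1.9" > "0.1.10"                                # lexical comparison: wrong order
assert version.parse("0.1.9") < version.parse("0.1.10")  # semantic comparison: right order
```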
+ ) + + return import_data diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 414ef0224a..ac05cbc4f5 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -4,7 +4,7 @@ import logging import random import time import uuid -from typing import Optional +from typing import Any, Optional from flask_login import current_user from sqlalchemy import func @@ -675,7 +675,7 @@ class DocumentService: def save_document_with_dataset_id( dataset: Dataset, document_data: dict, - account: Account, + account: Account | Any, dataset_process_rule: Optional[DatasetProcessRule] = None, created_from: str = "web", ): @@ -736,11 +736,12 @@ class DocumentService: dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model documents = [] - batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) if document_data.get("original_document_id"): document = DocumentService.update_document_with_dataset_id(dataset, document_data, account) documents.append(document) + batch = document.batch else: + batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) # save process rule if not dataset_process_rule: process_rule = document_data["process_rule"] @@ -921,7 +922,7 @@ class DocumentService: if duplicate_document_ids: duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) - return documents, batch + return documents, batch @staticmethod def check_documents_upload_quota(count: int, features: FeatureModel): diff --git a/api/services/file_service.py b/api/services/file_service.py index 6193a39669..976111502c 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -1,10 +1,9 @@ import datetime import hashlib import uuid -from typing import Literal, Union +from typing import Any, Literal, Union from flask_login import current_user -from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound from configs import dify_config @@ -21,7 +20,8 @@ from extensions.ext_storage import storage from models.account import Account from models.enums import CreatedByRole from models.model import EndUser, UploadFile -from services.errors.file import FileNotExistsError, FileTooLargeError, UnsupportedFileTypeError + +from .errors.file import FileTooLargeError, UnsupportedFileTypeError PREVIEW_WORDS_LIMIT = 3000 @@ -29,38 +29,28 @@ PREVIEW_WORDS_LIMIT = 3000 class FileService: @staticmethod def upload_file( - file: FileStorage, user: Union[Account, EndUser], source: Literal["datasets"] | None = None + *, + filename: str, + content: bytes, + mimetype: str, + user: Union[Account, EndUser, Any], + source: Literal["datasets"] | None = None, + source_url: str = "", ) -> UploadFile: - # get file name - filename = file.filename - if not filename: - raise FileNotExistsError - extension = filename.split(".")[-1] + # get file extension + extension = filename.split(".")[-1].lower() if len(filename) > 200: filename = filename.split(".")[0][:200] + "." 
+ extension if source == "datasets" and extension not in DOCUMENT_EXTENSIONS: raise UnsupportedFileTypeError() - # select file size limit - if extension in IMAGE_EXTENSIONS: - file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 - elif extension in VIDEO_EXTENSIONS: - file_size_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024 - elif extension in AUDIO_EXTENSIONS: - file_size_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024 - else: - file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 - - # read file content - file_content = file.read() # get file size - file_size = len(file_content) + file_size = len(content) # check if the file size is exceeded - if file_size > file_size_limit: - message = f"File size exceeded. {file_size} > {file_size_limit}" - raise FileTooLargeError(message) + if not FileService.is_file_size_within_limit(extension=extension, file_size=file_size): + raise FileTooLargeError # generate file key file_uuid = str(uuid.uuid4()) @@ -74,7 +64,7 @@ class FileService: file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension # save file to storage - storage.save(file_key, file_content) + storage.save(file_key, content) # save file to db upload_file = UploadFile( @@ -84,12 +74,13 @@ class FileService: name=filename, size=file_size, extension=extension, - mime_type=file.mimetype, + mime_type=mimetype, created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER), created_by=user.id, created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=False, - hash=hashlib.sha3_256(file_content).hexdigest(), + hash=hashlib.sha3_256(content).hexdigest(), + source_url=source_url, ) db.session.add(upload_file) @@ -97,6 +88,19 @@ class FileService: return upload_file + @staticmethod + def is_file_size_within_limit(*, extension: str, file_size: int) -> bool: + if extension in IMAGE_EXTENSIONS: + file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + elif extension in VIDEO_EXTENSIONS: + file_size_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024 + elif extension in AUDIO_EXTENSIONS: + file_size_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024 + else: + file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 + + return file_size <= file_size_limit + @staticmethod def upload_text(text: str, text_name: str) -> UploadFile: if len(text_name) > 200: diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index ba397167b2..fa4a2eb36c 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -91,5 +91,10 @@ INNER_API_KEY= # Marketplace configuration MARKETPLACE_API_URL= +# VESSL AI Credentials +VESSL_AI_MODEL_NAME= +VESSL_AI_API_KEY= +VESSL_AI_ENDPOINT_URL= + # Gitee AI Credentials GITEE_AI_API_KEY= diff --git a/api/tests/integration_tests/model_runtime/vessl_ai/__init__.py b/api/tests/integration_tests/model_runtime/vessl_ai/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py b/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py new file mode 100644 index 0000000000..7797d0f8e4 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py @@ -0,0 +1,131 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, 
LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.vessl_ai.llm.llm import VesslAILargeLanguageModel + + +def test_validate_credentials(): + model = VesslAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": "invalid_key", + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + ) + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": "http://invalid_url", + "mode": "chat", + }, + ) + + model.validate_credentials( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + ) + + +def test_invoke_model(): + model = VesslAILargeLanguageModel() + + response = model.invoke( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = VesslAILargeLanguageModel() + + response = model.invoke( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +def test_get_num_tokens(): + model = VesslAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 21 diff --git a/api/tests/unit_tests/controllers/test_compare_versions.py b/api/tests/unit_tests/controllers/test_compare_versions.py index 87902b6d44..9db57a8446 100644 --- a/api/tests/unit_tests/controllers/test_compare_versions.py +++ b/api/tests/unit_tests/controllers/test_compare_versions.py @@ -22,17 +22,3 @@ from controllers.console.version import _has_new_version ) def test_has_new_version(latest_version, 
current_version, expected): assert _has_new_version(latest_version=latest_version, current_version=current_version) == expected - - -def test_has_new_version_invalid_input(): - with pytest.raises(ValueError): - _has_new_version(latest_version="1.0", current_version="1.0.0") - - with pytest.raises(ValueError): - _has_new_version(latest_version="1.0.0", current_version="1.0") - - with pytest.raises(ValueError): - _has_new_version(latest_version="invalid", current_version="1.0.0") - - with pytest.raises(ValueError): - _has_new_version(latest_version="1.0.0", current_version="invalid") diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py new file mode 100644 index 0000000000..def6c2a232 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -0,0 +1,125 @@ +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file import File, FileTransferMethod, FileType +from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState +from core.workflow.nodes.answer import AnswerStreamGenerateRoute +from core.workflow.nodes.end import EndStreamParam +from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions +from core.workflow.nodes.llm.node import LLMNode +from models.enums import UserFrom +from models.workflow import WorkflowType + + +class TestLLMNode: + @pytest.fixture + def llm_node(self): + data = LLMNodeData( + title="Test LLM", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + prompt_template=[], + memory=None, + context=ContextConfig(enabled=False), + vision=VisionConfig( + enabled=True, + configs=VisionConfigOptions( + variable_selector=["sys", "files"], + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ), + ) + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + node = LLMNode( + id="1", + config={ + "id": "1", + "data": data.model_dump(), + }, + graph_init_params=GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ), + graph=Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, + ), + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, + ), + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ), + ) + return node + + def test_fetch_files_with_file_segment(self, llm_node): + file = File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1", + ) + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [file] + + def test_fetch_files_with_array_file_segment(self, llm_node): + files = [ + File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1", + ), + File( + id="2", + 
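+                # second fixture file: mirrors the first, with a distinct id/related_id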
tenant_id="test", + type=FileType.IMAGE, + filename="test2.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="2", + ), + ] + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == files + + def test_fetch_files_with_none_segment(self, llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + def test_fetch_files_with_array_any_segment(self, llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + def test_fetch_files_with_non_existent_variable(self, llm_node): + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] diff --git a/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py b/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py new file mode 100644 index 0000000000..842e8268d1 --- /dev/null +++ b/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py @@ -0,0 +1,47 @@ +import pytest +from packaging import version + +from services.app_dsl_service import AppDslService +from services.app_dsl_service.exc import DSLVersionNotSupportedError +from services.app_dsl_service.service import _check_or_fix_dsl, current_dsl_version + + +class TestAppDSLService: + @pytest.mark.skip(reason="Test skipped") + def test_check_or_fix_dsl_missing_version(self): + import_data = {} + result = _check_or_fix_dsl(import_data) + assert result["version"] == "0.1.0" + assert result["kind"] == "app" + + @pytest.mark.skip(reason="Test skipped") + def test_check_or_fix_dsl_missing_kind(self): + import_data = {"version": "0.1.0"} + result = _check_or_fix_dsl(import_data) + assert result["kind"] == "app" + + @pytest.mark.skip(reason="Test skipped") + def test_check_or_fix_dsl_older_version(self): + import_data = {"version": "0.0.9", "kind": "app"} + result = _check_or_fix_dsl(import_data) + assert result["version"] == "0.0.9" + + @pytest.mark.skip(reason="Test skipped") + def test_check_or_fix_dsl_current_version(self): + import_data = {"version": current_dsl_version, "kind": "app"} + result = _check_or_fix_dsl(import_data) + assert result["version"] == current_dsl_version + + @pytest.mark.skip(reason="Test skipped") + def test_check_or_fix_dsl_newer_version(self): + current_version = version.parse(current_dsl_version) + newer_version = f"{current_version.major}.{current_version.minor + 1}.0" + import_data = {"version": newer_version, "kind": "app"} + with pytest.raises(DSLVersionNotSupportedError): + _check_or_fix_dsl(import_data) + + @pytest.mark.skip(reason="Test skipped") + def test_check_or_fix_dsl_invalid_kind(self): + import_data = {"version": current_dsl_version, "kind": "invalid"} + result = _check_or_fix_dsl(import_data) + assert result["kind"] == "app" diff --git a/docker/.env.example b/docker/.env.example index ef2f331c11..34b2136302 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -558,6 +558,22 @@ ETL_TYPE=dify # For example: http://unstructured:8000/general/v0/general UNSTRUCTURED_API_URL= +# ------------------------------ +# Model Configuration +# ------------------------------ + +# The maximum number of tokens allowed for prompt generation. 
+# This setting controls the upper limit of tokens that can be used by the LLM +# when generating a prompt in the prompt generation tool. +# Default: 512 tokens. +PROMPT_GENERATION_MAX_TOKENS=512 + +# The maximum number of tokens allowed for code generation. +# This setting controls the upper limit of tokens that can be used by the LLM +# when generating code in the code generation tool. +# Default: 1024 tokens. +CODE_GENERATION_MAX_TOKENS=1024 + # ------------------------------ # Multi-modal Configuration # ------------------------------ @@ -572,6 +588,12 @@ MULTIMODAL_SEND_IMAGE_FORMAT=base64 # Upload image file size limit, default 10M. UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 +# Upload video file size limit, default 100M. +UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 + +# Upload audio file size limit, default 50M. +UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 + # ------------------------------ # Sentry Configuration # Used for application monitoring and error log tracking. diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 31624285b1..2eea273e72 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -56,6 +56,7 @@ services: SANDBOX_PORT: ${SANDBOX_PORT:-8194} volumes: - ./volumes/sandbox/dependencies:/dependencies + - ./volumes/sandbox/conf:/conf healthcheck: test: [ "CMD", "curl", "-f", "http://localhost:8194/health" ] networks: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 06c99b5eab..112e9a2702 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -207,8 +207,12 @@ x-shared-env: &shared-api-worker-env UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} ETL_TYPE: ${ETL_TYPE:-dify} UNSTRUCTURED_API_URL: ${UNSTRUCTURED_API_URL:-} + PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512} + CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024} MULTIMODAL_SEND_IMAGE_FORMAT: ${MULTIMODAL_SEND_IMAGE_FORMAT:-base64} UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10} + UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100} + UPLOAD_AUDIO_FILE_SIZE_LIMIT: ${UPLOAD_AUDIO_FILE_SIZE_LIMIT:-50} SENTRY_DSN: ${API_SENTRY_DSN:-} SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0} SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0} diff --git a/docker/volumes/sandbox/conf/config.yaml b/docker/volumes/sandbox/conf/config.yaml new file mode 100644 index 0000000000..8c1a1deb54 --- /dev/null +++ b/docker/volumes/sandbox/conf/config.yaml @@ -0,0 +1,14 @@ +app: + port: 8194 + debug: True + key: dify-sandbox +max_workers: 4 +max_requests: 50 +worker_timeout: 5 +python_path: /usr/local/bin/python3 +enable_network: True # please make sure there is no network risk in your environment +allowed_syscalls: # please leave it empty if you have no idea how seccomp works +proxy: + socks5: '' + http: '' + https: '' diff --git a/docker/volumes/sandbox/conf/config.yaml.example b/docker/volumes/sandbox/conf/config.yaml.example new file mode 100644 index 0000000000..f92c19e51a --- /dev/null +++ b/docker/volumes/sandbox/conf/config.yaml.example @@ -0,0 +1,35 @@ +app: + port: 8194 + debug: True + key: dify-sandbox +max_workers: 4 +max_requests: 50 +worker_timeout: 5 +python_path: /usr/local/bin/python3 +python_lib_path: + - /usr/local/lib/python3.10 + - /usr/lib/python3.10 + - /usr/lib/python3 + - /usr/lib/x86_64-linux-gnu + - /etc/ssl/certs/ca-certificates.crt + - /etc/nsswitch.conf + - /etc/hosts + - /etc/resolv.conf + - 
/run/systemd/resolve/stub-resolv.conf + - /run/resolvconf/resolv.conf + - /etc/localtime + - /usr/share/zoneinfo + - /etc/timezone + # add more paths if needed +python_pip_mirror_url: https://pypi.tuna.tsinghua.edu.cn/simple +nodejs_path: /usr/local/bin/node +enable_network: True +allowed_syscalls: + - 1 + - 2 + - 3 + # add all the syscalls which you require +proxy: + socks5: '' + http: '' + https: '' diff --git a/web/app/(commonLayout)/datasets/DatasetFooter.tsx b/web/app/(commonLayout)/datasets/DatasetFooter.tsx index 6eac815a1a..b87098000f 100644 --- a/web/app/(commonLayout)/datasets/DatasetFooter.tsx +++ b/web/app/(commonLayout)/datasets/DatasetFooter.tsx @@ -9,8 +9,8 @@ const DatasetFooter = () => {

{t('dataset.didYouKnow')}

- {t('dataset.intro1')}{t('dataset.intro2')}{t('dataset.intro3')}
- {t('dataset.intro4')}{t('dataset.intro5')}{t('dataset.intro6')} + {t('dataset.intro1')}{t('dataset.intro2')}{t('dataset.intro3')}
+ {t('dataset.intro4')}{t('dataset.intro5')}{t('dataset.intro6')}

) diff --git a/web/app/(commonLayout)/datasets/Doc.tsx b/web/app/(commonLayout)/datasets/Doc.tsx index a6dd8c23ef..553dca5008 100644 --- a/web/app/(commonLayout)/datasets/Doc.tsx +++ b/web/app/(commonLayout)/datasets/Doc.tsx @@ -1,6 +1,6 @@ 'use client' -import type { FC } from 'react' +import { type FC, useEffect } from 'react' import { useContext } from 'use-context-selector' import TemplateEn from './template/template.en.mdx' import TemplateZh from './template/template.zh.mdx' @@ -14,6 +14,13 @@ const Doc: FC = ({ apiBaseUrl, }) => { const { locale } = useContext(I18n) + + useEffect(() => { + const hash = location.hash + if (hash) + document.querySelector(hash)?.scrollIntoView() + }, []) + return (
{ diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index e264fd707e..263230d049 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -20,17 +20,17 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
- This api is based on an existing Knowledge and creates a new document through text based on this Knowledge. + This API is based on an existing knowledge and creates a new document through text based on this knowledge. ### Params @@ -50,7 +50,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from Index mode - high_quality High quality: embedding using embedding model, built as vector database index - - economy Economy: Build using inverted index of Keyword Table Index + - economy Economy: Build using inverted index of keyword table index Processing rules @@ -62,7 +62,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - remove_extra_spaces Replace consecutive spaces, newlines, tabs - remove_urls_emails Delete URL, email address - enabled (bool) Whether to select this rule or not. If no document ID is passed in, it represents the default value. - - segmentation (object) segmentation rules + - segmentation (object) Segmentation rules - separator Custom segment identifier, currently only allows one delimiter to be set. Default is \n - max_tokens Maximum length (token) defaults to 1000 @@ -72,11 +72,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_text' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create-by-text' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -123,17 +123,17 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
- This api is based on an existing Knowledge and creates a new document through a file based on this Knowledge. + This API is based on an existing knowledge and creates a new document through a file based on this knowledge. ### Params @@ -145,17 +145,17 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - - original_document_id Source document ID (optional) + - original_document_id Source document ID (optional) - Used to re-upload the document or modify the document cleaning and segmentation configuration. The missing information is copied from the source document - The source document cannot be an archived document - When original_document_id is passed in, the update operation is performed on behalf of the document. process_rule is a fillable item. If not filled in, the segmentation method of the source document will be used by default - When original_document_id is not passed in, the new operation is performed on behalf of the document, and process_rule is required - - indexing_technique Index mode + - indexing_technique Index mode - high_quality High quality: embedding using embedding model, built as vector database index - - economy Economy: Build using inverted index of Keyword Table Index + - economy Economy: Build using inverted index of keyword table index - - process_rule Processing rules + - process_rule Processing rules - mode (string) Cleaning, segmentation mode, automatic / custom - rules (object) Custom rules (in automatic mode, this field is empty) - pre_processing_rules (array[object]) Preprocessing rules @@ -164,7 +164,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - remove_extra_spaces Replace consecutive spaces, newlines, tabs - remove_urls_emails Delete URL, email address - enabled (bool) Whether to select this rule or not. If no document ID is passed in, it represents the default value. - - segmentation (object) segmentation rules + - segmentation (object) Segmentation rules - separator Custom segment identifier, currently only allows one delimiter to be set. Default is \n - max_tokens Maximum length (token) defaults to 1000 @@ -177,11 +177,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_file' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create-by-file' \ --header 'Authorization: Bearer {api_key}' \ --form 'data="{\"name\":\"Dify\",\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}";type=text/plain' \ --form 'file=@"/path/to/file"' @@ -221,12 +221,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -240,9 +240,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from Knowledge description (optional) - Index Technique (optional) - - high_quality high_quality - - economy economy + Index technique (optional) + - high_quality High quality + - economy Economy Permission @@ -252,21 +252,21 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from Provider (optional, default: vendor) - - vendor vendor - - external external knowledge + - vendor Vendor + - external External knowledge - External Knowledge api id (optional) + External knowledge API ID (optional) - External Knowledge id (optional) + External knowledge ID (optional) - @@ -306,12 +306,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -327,9 +327,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - @@ -369,12 +369,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -406,17 +406,17 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
- This api is based on an existing Knowledge and updates the document through text based on this Knowledge. + This API is based on an existing knowledge and updates the document through text based on this knowledge. ### Params @@ -446,7 +446,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - remove_extra_spaces Replace consecutive spaces, newlines, tabs - remove_urls_emails Delete URL, email address - enabled (bool) Whether to select this rule or not. If no document ID is passed in, it represents the default value. - - segmentation (object) segmentation rules + - segmentation (object) Segmentation rules - separator Custom segment identifier, currently only allows one delimiter to be set. Default is \n - max_tokens Maximum length (token) defaults to 1000 @@ -456,11 +456,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update_by_text' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update-by-text' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -503,17 +503,17 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
- This api is based on an existing Knowledge, and updates documents through files based on this Knowledge + This API is based on an existing knowledge, and updates documents through files based on this knowledge ### Params @@ -543,7 +543,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - remove_extra_spaces Replace consecutive spaces, newlines, tabs - remove_urls_emails Delete URL, email address - enabled (bool) Whether to select this rule or not. If no document ID is passed in, it represents the default value. - - segmentation (object) segmentation rules + - segmentation (object) Segmentation rules - separator Custom segment identifier, currently only allows one delimiter to be set. Default is \n - max_tokens Maximum length (token) defaults to 1000 @@ -553,11 +553,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update_by_file' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update-by-file' \ --header 'Authorization: Bearer {api_key}' \ --form 'data="{\"name\":\"Dify\",\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}";type=text/plain' \ --form 'file=@"/path/to/file"' @@ -597,12 +597,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -652,12 +652,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -694,12 +694,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -714,13 +714,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Query - Search keywords, currently only search document names(optional) + Search keywords, currently only search document names (optional) - Page number(optional) + Page number (optional) - Number of items returned, default 20, range 1-100(optional) + Number of items returned, default 20, range 1-100 (optional) @@ -769,12 +769,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -792,9 +792,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - - content (text) Text content/question content, required - - answer (text) Answer content, if the mode of the Knowledge is qa mode, pass the value(optional) - - keywords (list) Keywords(optional) + - content (text) Text content / question content, required + - answer (text) Answer content, if the mode of the knowledge is Q&A mode, pass the value (optional) + - keywords (list) Keywords (optional) @@ -855,12 +855,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -878,10 +878,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Query - keyword,choosable + Keyword (optional) - Search status,completed + Search status, completed @@ -933,12 +933,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -979,12 +979,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -1005,10 +1005,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - - content (text) text content/question content,required - - answer (text) Answer content, not required, passed if the Knowledge is in qa mode - - keywords (list) keyword, not required - - enabled (bool) false/true, not required + - content (text) Text content / question content, required + - answer (text) Answer content, passed if the knowledge is in Q&A mode (optional) + - keywords (list) Keyword (optional) + - enabled (bool) False / true (optional) @@ -1067,41 +1067,41 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
### Path - Dataset ID + Knowledge ID ### Request Body - retrieval keywordc + Query keyword - retrieval keyword(Optional, if not filled, it will be recalled according to the default method) + Retrieval model (optional, if not filled, it will be recalled according to the default method) - search_method (text) Search method: One of the following four keywords is required - keyword_search Keyword search - semantic_search Semantic search - full_text_search Full-text search - hybrid_search Hybrid search - - reranking_enable (bool) Whether to enable reranking, optional, required if the search mode is semantic_search or hybrid_search - - reranking_mode (object) Rerank model configuration, optional, required if reranking is enabled + - reranking_enable (bool) Whether to enable reranking, required if the search mode is semantic_search or hybrid_search (optional) + - reranking_mode (object) Rerank model configuration, required if reranking is enabled - reranking_provider_name (string) Rerank model provider - reranking_model_name (string) Rerank model name - weights (double) Semantic search weight setting in hybrid search mode - - top_k (integer) Number of results to return, optional + - top_k (integer) Number of results to return (optional) - score_threshold_enabled (bool) Whether to enable score threshold - score_threshold (double) Score threshold @@ -1114,26 +1114,26 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -1212,7 +1212,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
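The hunks above rename the dataset retrieval endpoint from `hit_testing` to `retrieve` and clean up the `retrieval_model` field docs. As a reading aid, here is a minimal, hypothetical TypeScript sketch of a call against the renamed endpoint. It uses only the path, headers, and body fields documented in the hunk; `apiBaseUrl`, `apiKey`, and `datasetId` are placeholders, and the response shape is not shown in this diff, so the sketch simply returns the parsed JSON.

```ts
// Hypothetical client sketch for the renamed endpoint (POST .../retrieve).
// Field names mirror the request-body docs above; all values are placeholders.
async function retrieveFromKnowledge(
  apiBaseUrl: string,
  apiKey: string,
  datasetId: string,
  query: string,
) {
  const response = await fetch(`${apiBaseUrl}/datasets/${datasetId}/retrieve`, {
    method: 'POST',
    headers: {
      'Authorization': `Bearer ${apiKey}`,
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({
      query, // the query keyword
      retrieval_model: {
        // one of: keyword_search, semantic_search, full_text_search, hybrid_search
        search_method: 'keyword_search',
        // reranking_enable is required when search_method is semantic_search or hybrid_search
        reranking_enable: false,
        top_k: 3,
        score_threshold_enabled: false,
      },
    }),
  })
  if (!response.ok)
    throw new Error(`retrieve failed with status ${response.status}`)
  return response.json()
}
```

Omitting `retrieval_model` entirely should also work, since the docs state the knowledge base's default retrieval method is used when it is not supplied.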
diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index 5d52664db4..9c25d1e7bb 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -20,13 +20,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -50,7 +50,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from 索引方式 - high_quality 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 - - economy 经济:使用 Keyword Table Index 的倒排索引进行构建 + - economy 经济:使用 keyword table index 的倒排索引进行构建 处理规则 @@ -64,7 +64,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - enabled (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - - max_tokens 最大长度 (token) 默认为 1000 + - max_tokens 最大长度(token)默认为 1000 @@ -72,11 +72,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_text' \ + curl --location --request --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create-by-text' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -123,13 +123,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -145,17 +145,17 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - - original_document_id 源文档 ID (选填) + - original_document_id 源文档 ID(选填) - 用于重新上传文档或修改文档清洗、分段配置,缺失的信息从源文档复制 - 源文档不可为归档的文档 - 当传入 original_document_id 时,代表文档进行更新操作,process_rule 为可填项目,不填默认使用源文档的分段方式 - 未传入 original_document_id 时,代表文档进行新增操作,process_rule 为必填 - - indexing_technique 索引方式 + - indexing_technique 索引方式 - high_quality 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 - - economy 经济:使用 Keyword Table Index 的倒排索引进行构建 + - economy 经济:使用 keyword table index 的倒排索引进行构建 - - process_rule 处理规则 + - process_rule 处理规则 - mode (string) 清洗、分段模式 ,automatic 自动 / custom 自定义 - rules (object) 自定义规则(自动模式下,该字段为空) - pre_processing_rules (array[object]) 预处理规则 @@ -166,7 +166,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - enabled (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - - max_tokens 最大长度 (token) 默认为 1000 + - max_tokens 最大长度(token)默认为 1000 需要上传的文件。 @@ -177,11 +177,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_file' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create-by-file' \ --header 'Authorization: Bearer {api_key}' \ --form 'data="{\"name\":\"Dify\",\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}";type=text/plain' \ --form 'file=@"/path/to/file"' @@ -221,7 +221,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
economy 经济 - 权限(选填,默认only_me) + 权限(选填,默认 only_me) - only_me 仅自己 - all_team_members 所有团队成员 - partial_members 部分团队成员 - provider,(选填,默认 vendor) + Provider(选填,默认 vendor) - vendor 上传文件 - external 外部知识库 @@ -264,9 +264,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - @@ -306,7 +306,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
---- +
---- +
---- +
@@ -431,7 +431,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - 文档名称 (选填) + 文档名称(选填) 文档内容(选填) @@ -448,7 +448,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - enabled (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - - max_tokens 最大长度 (token) 默认为 1000 + - max_tokens 最大长度(token)默认为 1000 @@ -456,11 +456,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update_by_text' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update-by-text' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -503,13 +503,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -528,7 +528,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - 文档名称 (选填) + 文档名称(选填) 需要上传的文件 @@ -545,7 +545,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - enabled (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - - max_tokens 最大长度 (token) 默认为 1000 + - max_tokens 最大长度(token)默认为 1000 @@ -553,11 +553,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update_by_file' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update-by-file' \ --header 'Authorization: Bearer {api_key}' \ --form 'data="{\"name\":\"Dify\",\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}";type=text/plain' \ --form 'file=@"/path/to/file"' @@ -597,7 +597,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
---- +
---- +
---- +
- content (text) 文本内容/问题内容,必填 - - answer (text) 答案内容,非必填,如果知识库的模式为qa模式则传值 + - answer (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值 - keywords (list) 关键字,非必填 @@ -855,7 +855,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
---- +
---- +
- content (text) 文本内容/问题内容,必填 - - answer (text) 答案内容,非必填,如果知识库的模式为qa模式则传值 + - answer (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值 - keywords (list) 关键字,非必填 - enabled (bool) false/true,非必填 @@ -1068,13 +1068,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -1088,23 +1088,23 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - 召回关键词 + 检索关键词 - 召回参数(选填,如不填,按照默认方式召回) + 检索参数(选填,如不填,按照默认方式召回) - search_method (text) 检索方法:以下三个关键字之一,必填 - keyword_search 关键字检索 - semantic_search 语义检索 - full_text_search 全文检索 - hybrid_search 混合检索 - - reranking_enable (bool) 是否启用 Reranking,非必填,如果检索模式为semantic_search模式或者hybrid_search则传值 + - reranking_enable (bool) 是否启用 Reranking,非必填,如果检索模式为 semantic_search 模式或者 hybrid_search 则传值 - reranking_mode (object) Rerank模型配置,非必填,如果启用了 reranking 则传值 - reranking_provider_name (string) Rerank 模型提供商 - reranking_model_name (string) Rerank 模型名称 - weights (double) 混合检索模式下语意检索的权重设置 - top_k (integer) 返回结果数量,非必填 - - score_threshold_enabled (bool) 是否开启Score阈值 - - score_threshold (double) Score阈值 + - score_threshold_enabled (bool) 是否开启 score 阈值 + - score_threshold (double) Score 阈值 未启用字段 @@ -1115,26 +1115,26 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -1214,7 +1214,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
diff --git a/web/app/components/app/configuration/dataset-config/index.tsx b/web/app/components/app/configuration/dataset-config/index.tsx index 2c082d8815..0d9d575c1e 100644 --- a/web/app/components/app/configuration/dataset-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/index.tsx @@ -15,6 +15,7 @@ import { AppType } from '@/types/app' import type { DataSet } from '@/models/datasets' import { getMultipleRetrievalConfig, + getSelectedDatasetsMode, } from '@/app/components/workflow/nodes/knowledge-retrieval/utils' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' @@ -38,6 +39,7 @@ const DatasetConfig: FC = () => { isAgent, datasetConfigs, setDatasetConfigs, + setRerankSettingModalOpen, } = useContext(ConfigContext) const formattingChangedDispatcher = useFormattingChangedDispatcher() @@ -55,6 +57,20 @@ const DatasetConfig: FC = () => { ...(datasetConfigs as any), ...retrievalConfig, }) + const { + allExternal, + allInternal, + mixtureInternalAndExternal, + mixtureHighQualityAndEconomic, + inconsistentEmbeddingModel, + } = getSelectedDatasetsMode(filteredDataSets) + + if ( + (allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel)) + || mixtureInternalAndExternal + || allExternal + ) + setRerankSettingModalOpen(true) formattingChangedDispatcher() } diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 6b1983f5e2..75f0c33349 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -266,7 +266,7 @@ const ConfigContent: FC = ({
{ - selectedDatasetsMode.allEconomic && ( + selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && (
{ let errMsg = '' if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) { - if (!tempDataSetConfigs.reranking_model?.reranking_model_name && (rerankDefaultModel && !isRerankDefaultModelValid)) + if (tempDataSetConfigs.reranking_enable + && tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel + && !isRerankDefaultModelValid + ) errMsg = t('appDebug.datasetConfig.rerankModelRequired') } if (errMsg) { @@ -62,7 +66,9 @@ const ParamsConfig = ({ if (!isValid()) return const config = { ...tempDataSetConfigs } - if (config.retrieval_model === RETRIEVE_TYPE.multiWay && !config.reranking_model) { + if (config.retrieval_model === RETRIEVE_TYPE.multiWay + && config.reranking_mode === RerankingModeEnum.RerankingModel + && !config.reranking_model) { config.reranking_model = { reranking_provider_name: rerankDefaultModel?.provider?.provider, reranking_model_name: rerankDefaultModel?.model, diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index af50fc65c3..bf6c5e79c8 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -252,12 +252,18 @@ const Configuration: FC = () => { } hideSelectDataSet() const { - allEconomic, + allExternal, + allInternal, + mixtureInternalAndExternal, mixtureHighQualityAndEconomic, inconsistentEmbeddingModel, } = getSelectedDatasetsMode(newDatasets) - if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel) + if ( + (allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel)) + || mixtureInternalAndExternal + || allExternal + ) setRerankSettingModalOpen(true) const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 22585aa678..4c12cab581 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -36,6 +36,7 @@ import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import TextGeneration from '@/app/components/app/text-generate/item' import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import MessageLogModal from '@/app/components/base/message-log-modal' +import PromptLogModal from '@/app/components/base/prompt-log-modal' import { useStore as useAppStore } from '@/app/components/app/store' import { useAppContext } from '@/context/app-context' import useTimestamp from '@/hooks/use-timestamp' @@ -168,11 +169,13 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { const { userProfile: { timezone } } = useAppContext() const { formatTime } = useTimestamp() const { onClose, appDetail } = useContext(DrawerContext) - const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({ + const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, showPromptLogModal, setShowPromptLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({ currentLogItem: state.currentLogItem, setCurrentLogItem: state.setCurrentLogItem, showMessageLogModal: state.showMessageLogModal, setShowMessageLogModal: state.setShowMessageLogModal, + showPromptLogModal: state.showPromptLogModal, + setShowPromptLogModal: state.setShowPromptLogModal, currentLogModalActiveTab: state.currentLogModalActiveTab, }))) const { t } = useTranslation() @@ -192,8 +195,8 @@ function DetailPanel({ detail, onFeedback 
}: IDetailPanel) { conversation_id: detail.id, limit: 10, } - if (allChatItems.at(-1)?.id) - params.first_id = allChatItems.at(-1)?.id.replace('question-', '') + if (allChatItems[0]?.id) + params.first_id = allChatItems[0]?.id.replace('question-', '') const messageRes = await fetchChatMessages({ url: `/apps/${appDetail?.id}/chat-messages`, params, @@ -557,6 +560,16 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { defaultTab={currentLogModalActiveTab} /> )} + {showPromptLogModal && ( + { + setCurrentLogItem() + setShowPromptLogModal(false) + }} + /> + )}
) } diff --git a/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap b/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap index 070975bfa7..7da09c4529 100644 --- a/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap +++ b/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap @@ -1804,6 +1804,280 @@ exports[`build chat item tree and get thread messages should get thread messages ] `; +exports[`build chat item tree and get thread messages should work with partial messages 1`] = ` +[ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105809, + "files": [], + "id": "1019cd79-d141-4f9f-880a-fc1441cfd802", + "message_id": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "observation": "", + "position": 1, + "thought": "Sure! My number is 54. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105822, + "files": [], + "id": "0773bec7-b992-4a53-92b2-20ebaeae8798", + "message_id": "324bce32-c98c-435d-a66b-bac974ebb5ed", + "observation": "", + "position": 1, + "thought": "My number is 4729. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "My number is 4729. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "324bce32-c98c-435d-a66b-bac974ebb5ed", + "input": { + "inputs": {}, + "query": "3306", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4729. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.30", + "time": "09/11/2024 09:50 PM", + "tokens": 66, + }, + "parentMessageId": "question-324bce32-c98c-435d-a66b-bac974ebb5ed", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "3306", + "id": "question-324bce32-c98c-435d-a66b-bac974ebb5ed", + "isAnswer": false, + "message_files": [], + "parentMessageId": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + }, + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726107812, + "files": [], + "id": "5ca650f3-982c-4399-8b95-9ea241c76707", + "message_id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "observation": "", + "position": 1, + "thought": "My number is 4821. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726111024, + "files": [], + "id": "095cacab-afad-4387-a41d-1662578b8b13", + "message_id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "observation": "", + "position": 1, + "thought": "My number is 1456. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "My number is 1456. 
Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "input": { + "inputs": {}, + "query": "1003", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "1003", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 1456. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.38", + "time": "09/11/2024 11:17 PM", + "tokens": 86, + }, + "parentMessageId": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "1003", + "id": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "isAnswer": false, + "message_files": [], + "parentMessageId": "684b5396-4e91-4043-88e9-aabe48b21acc", + }, + ], + "content": "My number is 4821. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "input": { + "inputs": {}, + "query": "3306", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.48", + "time": "09/11/2024 10:23 PM", + "tokens": 66, + }, + "parentMessageId": "question-684b5396-4e91-4043-88e9-aabe48b21acc", + "siblingIndex": 1, + "workflow_run_id": null, + }, + ], + "content": "3306", + "id": "question-684b5396-4e91-4043-88e9-aabe48b21acc", + "isAnswer": false, + "message_files": [], + "parentMessageId": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + }, + ], + "content": "Sure! My number is 54. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "input": { + "inputs": {}, + "query": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.52", + "time": "09/11/2024 09:50 PM", + "tokens": 46, + }, + "parentMessageId": "question-cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. 
I'll start first, 38", + "id": "question-cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "isAnswer": false, + "message_files": [], + }, +] +`; + exports[`build chat item tree and get thread messages should work with real world messages 1`] = ` [ { diff --git a/web/app/components/base/chat/__tests__/utils.spec.ts b/web/app/components/base/chat/__tests__/utils.spec.ts index c602ac8a99..1dead1c949 100644 --- a/web/app/components/base/chat/__tests__/utils.spec.ts +++ b/web/app/components/base/chat/__tests__/utils.spec.ts @@ -255,4 +255,10 @@ describe('build chat item tree and get thread messages', () => { const threadMessages6_2 = getThreadMessages(tree6, 'ff4c2b43-48a5-47ad-9dc5-08b34ddba61b') expect(threadMessages6_2).toMatchSnapshot() }) + + const partialMessages = (realWorldMessages as ChatItemInTree[]).slice(-10) + const tree7 = buildChatItemTree(partialMessages) + it('should work with partial messages', () => { + expect(tree7).toMatchSnapshot() + }) }) diff --git a/web/app/components/base/chat/utils.ts b/web/app/components/base/chat/utils.ts index 16357361cf..61dfaecffc 100644 --- a/web/app/components/base/chat/utils.ts +++ b/web/app/components/base/chat/utils.ts @@ -134,6 +134,12 @@ function buildChatItemTree(allMessages: IChatItem[]): ChatItemInTree[] { } } + // If no messages have parentMessageId=null (indicating a root node), + // then we likely have a partial chat history. In this case, + // use the first available message as the root node. + if (rootNodes.length === 0 && allMessages.length > 0) + rootNodes.push(map[allMessages[0]!.id]!) + return rootNodes } diff --git a/web/app/components/base/file-uploader/file-uploader-in-attachment/file-item.tsx b/web/app/components/base/file-uploader/file-uploader-in-attachment/file-item.tsx index d22d6ff4ec..2a042bab40 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-attachment/file-item.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-attachment/file-item.tsx @@ -1,6 +1,5 @@ import { memo, - useMemo, } from 'react' import { RiDeleteBinLine, @@ -35,17 +34,9 @@ const FileInAttachmentItem = ({ onRemove, onReUpload, }: FileInAttachmentItemProps) => { - const { id, name, type, progress, supportFileType, base64Url, url } = file - const ext = getFileExtension(name, type) + const { id, name, type, progress, supportFileType, base64Url, url, isRemote } = file + const ext = getFileExtension(name, type, isRemote) const isImageFile = supportFileType === SupportUploadFileTypes.image - const nameArr = useMemo(() => { - const nameMatch = name.match(/(.+)\.([^.]+)$/) - - if (nameMatch) - return [nameMatch[1], nameMatch[2]] - - return [name, ''] - }, [name]) return (
-
{nameArr[0]}
- { - nameArr[1] && ( - .{nameArr[1]} - ) - } +
{name}
{ @@ -93,7 +79,11 @@ const FileInAttachmentItem = ({ ) } - {formatFileSize(file.size || 0)} + { + !!file.size && ( + {formatFileSize(file.size)} + ) + }
diff --git a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx index 6597373020..a051b89ec1 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx @@ -31,8 +31,8 @@ const FileItem = ({ onRemove, onReUpload, }: FileItemProps) => { - const { id, name, type, progress, url } = file - const ext = getFileExtension(name, type) + const { id, name, type, progress, url, isRemote } = file + const ext = getFileExtension(name, type, isRemote) const uploadError = progress === -1 return ( @@ -75,7 +75,9 @@ const FileItem = ({ ) } - {formatFileSize(file.size || 0)} + { + !!file.size && formatFileSize(file.size) + }
{ showDownloadAction && ( diff --git a/web/app/components/base/file-uploader/hooks.ts b/web/app/components/base/file-uploader/hooks.ts index 942e5d612a..a78c414913 100644 --- a/web/app/components/base/file-uploader/hooks.ts +++ b/web/app/components/base/file-uploader/hooks.ts @@ -25,7 +25,7 @@ import { TransferMethod } from '@/types/app' import { SupportUploadFileTypes } from '@/app/components/workflow/types' import type { FileUpload } from '@/app/components/base/features/types' import { formatFileSize } from '@/utils/format' -import { fetchRemoteFileInfo } from '@/service/common' +import { uploadRemoteFileInfo } from '@/service/common' import type { FileUploadConfigResponse } from '@/models/common' export const useFileSizeLimit = (fileUploadConfig?: FileUploadConfigResponse) => { @@ -49,7 +49,7 @@ export const useFile = (fileConfig: FileUpload) => { const params = useParams() const { imgSizeLimit, docSizeLimit, audioSizeLimit, videoSizeLimit } = useFileSizeLimit(fileConfig.fileUploadConfig) - const checkSizeLimit = (fileType: string, fileSize: number) => { + const checkSizeLimit = useCallback((fileType: string, fileSize: number) => { switch (fileType) { case SupportUploadFileTypes.image: { if (fileSize > imgSizeLimit) { @@ -120,7 +120,7 @@ export const useFile = (fileConfig: FileUpload) => { return true } } - } + }, [audioSizeLimit, docSizeLimit, imgSizeLimit, notify, t, videoSizeLimit]) const handleAddFile = useCallback((newFile: FileEntity) => { const { @@ -188,6 +188,17 @@ export const useFile = (fileConfig: FileUpload) => { } }, [fileStore, notify, t, handleUpdateFile, params]) + const startProgressTimer = useCallback((fileId: string) => { + const timer = setInterval(() => { + const files = fileStore.getState().files + const file = files.find(file => file.id === fileId) + + if (file && file.progress < 80 && file.progress >= 0) + handleUpdateFile({ ...file, progress: file.progress + 20 }) + else + clearTimeout(timer) + }, 200) + }, [fileStore, handleUpdateFile]) const handleLoadFileFromLink = useCallback((url: string) => { const allowedFileTypes = fileConfig.allowed_file_types @@ -197,19 +208,27 @@ export const useFile = (fileConfig: FileUpload) => { type: '', size: 0, progress: 0, - transferMethod: TransferMethod.remote_url, + transferMethod: TransferMethod.local_file, supportFileType: '', url, + isRemote: true, } handleAddFile(uploadingFile) + startProgressTimer(uploadingFile.id) - fetchRemoteFileInfo(url).then((res) => { + uploadRemoteFileInfo(url).then((res) => { const newFile = { ...uploadingFile, - type: res.file_type, - size: res.file_length, + type: res.mime_type, + size: res.size, progress: 100, - supportFileType: getSupportFileType(url, res.file_type, allowedFileTypes?.includes(SupportUploadFileTypes.custom)), + supportFileType: getSupportFileType(res.name, res.mime_type, allowedFileTypes?.includes(SupportUploadFileTypes.custom)), + uploadedId: res.id, + url: res.url, + } + if (!isAllowedFileExtension(res.name, res.mime_type, fileConfig.allowed_file_types || [], fileConfig.allowed_file_extensions || [])) { + notify({ type: 'error', message: t('common.fileUploader.fileExtensionNotSupport') }) + handleRemoveFile(uploadingFile.id) } if (!checkSizeLimit(newFile.supportFileType, newFile.size)) handleRemoveFile(uploadingFile.id) @@ -219,7 +238,7 @@ export const useFile = (fileConfig: FileUpload) => { notify({ type: 'error', message: t('common.fileUploader.pasteFileLinkInvalid') }) handleRemoveFile(uploadingFile.id) }) - }, [checkSizeLimit, handleAddFile, handleUpdateFile, notify, 
t, handleRemoveFile, fileConfig?.allowed_file_types]) + }, [checkSizeLimit, handleAddFile, handleUpdateFile, notify, t, handleRemoveFile, fileConfig?.allowed_file_types, fileConfig.allowed_file_extensions, startProgressTimer]) const handleLoadFileFromLinkSuccess = useCallback(() => { }, []) diff --git a/web/app/components/base/file-uploader/types.ts b/web/app/components/base/file-uploader/types.ts index ac4584bb4c..285023f0af 100644 --- a/web/app/components/base/file-uploader/types.ts +++ b/web/app/components/base/file-uploader/types.ts @@ -29,4 +29,5 @@ export type FileEntity = { uploadedId?: string base64Url?: string url?: string + isRemote?: boolean } diff --git a/web/app/components/base/file-uploader/utils.ts b/web/app/components/base/file-uploader/utils.ts index 4c7ef0d89b..eb9199d74b 100644 --- a/web/app/components/base/file-uploader/utils.ts +++ b/web/app/components/base/file-uploader/utils.ts @@ -43,10 +43,13 @@ export const fileUpload: FileUpload = ({ }) } -export const getFileExtension = (fileName: string, fileMimetype: string) => { +export const getFileExtension = (fileName: string, fileMimetype: string, isRemote?: boolean) => { if (fileMimetype) return mime.getExtension(fileMimetype) || '' + if (isRemote) + return '' + if (fileName) { const fileNamePair = fileName.split('.') const fileNamePairLength = fileNamePair.length diff --git a/web/app/components/base/image-uploader/image-list.tsx b/web/app/components/base/image-uploader/image-list.tsx index 8d5d1a1af5..35f6149b13 100644 --- a/web/app/components/base/image-uploader/image-list.tsx +++ b/web/app/components/base/image-uploader/image-list.tsx @@ -133,6 +133,7 @@ const ImageList: FC = ({ setImagePreviewUrl('')} + title='' /> )}
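Since `getFileExtension` now takes an `isRemote` flag, a short usage sketch may clarify the intended precedence. This is illustrative only: the import path is assumed from the file layout above, and the name-parsing fallback is elided in the hunk, so its exact behavior is inferred from the surrounding code.

```ts
import { getFileExtension } from '@/app/components/base/file-uploader/utils'

// A known MIME type always wins, regardless of the name
getFileExtension('report.pdf', 'application/pdf') // => 'pdf'

// A remote file with an unknown MIME type no longer guesses from the URL
getFileExtension('https://example.com/download', '', true) // => ''

// A local file with no MIME type still falls back to parsing the name
getFileExtension('notes.txt', '') // => 'txt' (per the elided name-split fallback)
```

The `isRemote` short-circuit matters because pasted links often have no meaningful extension in their URL, so showing a guessed suffix in the attachment item would be misleading.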
diff --git a/web/app/components/base/search-input/index.tsx b/web/app/components/base/search-input/index.tsx index 4b3821da5a..89345fbe32 100644 --- a/web/app/components/base/search-input/index.tsx +++ b/web/app/components/base/search-input/index.tsx @@ -1,5 +1,5 @@ import type { FC } from 'react' -import { useState } from 'react' +import { useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { RiSearchLine } from '@remixicon/react' import cn from '@/utils/classnames' @@ -12,6 +12,7 @@ type SearchInputProps = { onChange: (v: string) => void white?: boolean } + const SearchInput: FC = ({ placeholder, className, @@ -21,6 +22,7 @@ const SearchInput: FC = ({ }) => { const { t } = useTranslation() const [focus, setFocus] = useState(false) + const isComposing = useRef(false) return (
= ({ placeholder={placeholder || t('common.operation.search')!} value={value} onChange={(e) => { - onChange(e.target.value) + if (!isComposing.current) + onChange(e.target.value) + }} + onCompositionStart={() => { + isComposing.current = true + }} + onCompositionEnd={() => { + isComposing.current = false }} onFocus={() => setFocus(true)} onBlur={() => setFocus(false)} diff --git a/web/app/components/develop/md.tsx b/web/app/components/develop/md.tsx index 87f7b35aaf..26b4007c87 100644 --- a/web/app/components/develop/md.tsx +++ b/web/app/components/develop/md.tsx @@ -39,6 +39,7 @@ export const Heading = function H2({ } return ( <> +
{method}
{/* */}
diff --git a/web/app/components/develop/template/template_advanced_chat.en.mdx b/web/app/components/develop/template/template_advanced_chat.en.mdx
index 7d80367ce4..6642c5cedc 100644
--- a/web/app/components/develop/template/template_advanced_chat.en.mdx
+++ b/web/app/components/develop/template/template_advanced_chat.en.mdx
@@ -656,6 +656,11 @@ Chat applications support session persistence, allowing previous chat history to
     Return only pinned conversations as `true`, only non-pinned as `false`
+
+    Sorting field (optional), default: -updated_at (sorted in descending order by update time)
+    - Available values: created_at, -created_at, updated_at, -updated_at
+    - A "-" prefix before the field sorts in descending order; without it, ascending.
+
  ### Response
diff --git a/web/app/components/develop/template/template_advanced_chat.zh.mdx b/web/app/components/develop/template/template_advanced_chat.zh.mdx
index 690d700f05..8e64d63ac5 100755
--- a/web/app/components/develop/template/template_advanced_chat.zh.mdx
+++ b/web/app/components/develop/template/template_advanced_chat.zh.mdx
@@ -691,6 +691,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
     只返回置顶 true,只返回非置顶 false
+
+    排序字段(选填),默认 -updated_at(按更新时间倒序排列)
+    - 可选值:created_at, -created_at, updated_at, -updated_at
+    - 字段前面的符号代表顺序或倒序,-代表倒序
+
  ### Response
diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx
index 907a1ab0b4..a94016ca3a 100644
--- a/web/app/components/develop/template/template_chat.en.mdx
+++ b/web/app/components/develop/template/template_chat.en.mdx
@@ -690,6 +690,11 @@ Chat applications support session persistence, allowing previous chat history to
     Return only pinned conversations as `true`, only non-pinned as `false`
+
+    Sorting field (optional), default: -updated_at (sorted in descending order by update time)
+    - Available values: created_at, -created_at, updated_at, -updated_at
+    - A "-" prefix before the field sorts in descending order; without it, ascending.
+
  ### Response
diff --git a/web/app/components/develop/template/template_chat.zh.mdx b/web/app/components/develop/template/template_chat.zh.mdx
index f6dc7daa1e..92b13b2c7d 100644
--- a/web/app/components/develop/template/template_chat.zh.mdx
+++ b/web/app/components/develop/template/template_chat.zh.mdx
@@ -705,6 +705,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
     只返回置顶 true,只返回非置顶 false
+
+    排序字段(选填),默认 -updated_at(按更新时间倒序排列)
+    - 可选值:created_at, -created_at, updated_at, -updated_at
+    - 字段前面的符号代表顺序或倒序,-代表倒序
+
  ### Response
diff --git a/web/app/components/workflow/nodes/_base/components/editor/base.tsx b/web/app/components/workflow/nodes/_base/components/editor/base.tsx
index cca565c39d..44930427ae 100644
--- a/web/app/components/workflow/nodes/_base/components/editor/base.tsx
+++ b/web/app/components/workflow/nodes/_base/components/editor/base.tsx
@@ -26,7 +26,7 @@ type Props = {
   isFocus: boolean
   isInNode?: boolean
   onGenerated?: (prompt: string) => void
-  codeLanguages: CodeLanguage
+  codeLanguages?: CodeLanguage
   fileList?: FileEntity[]
   showFileList?: boolean
   showCodeGenerator?: boolean
@@ -78,7 +78,7 @@ const Base: FC<Props> = ({
   e.stopPropagation()
   }}>
   {headerRight}
-  {showCodeGenerator && (
+  {showCodeGenerator && codeLanguages && (
diff --git a/web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx b/web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx
index a31cde2c3c..28d07936d3 100644
--- a/web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx
+++ b/web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx
@@ -31,6 +31,7 @@ export type Props = {
   noWrapper?: boolean
   isExpand?: boolean
   showFileList?: boolean
+  onGenerated?: (value: string) => void
   showCodeGenerator?: boolean
 }
 
@@ -64,6 +65,7 @@ const CodeEditor: FC<Props> = ({
   noWrapper,
   isExpand,
   showFileList,
+  onGenerated,
   showCodeGenerator = false,
 }) => {
   const [isFocus, setIsFocus] = React.useState(false)
@@ -151,9 +153,6 @@
     return isFocus ? 'focus-theme' : 'blur-theme'
   })()
 
-  const handleGenerated = (code: string) => {
-    handleEditorChange(code)
-  }
   const main = (
     <>
@@ -205,7 +204,7 @@ const CodeEditor: FC<Props> = ({
         isFocus={isFocus && !readOnly}
         minHeight={minHeight}
         isInNode={isInNode}
-        onGenerated={handleGenerated}
+        onGenerated={onGenerated}
         codeLanguages={language}
         fileList={fileList}
         showFileList={showFileList}
diff --git a/web/app/components/workflow/nodes/code/code-parser.spec.ts b/web/app/components/workflow/nodes/code/code-parser.spec.ts
new file mode 100644
index 0000000000..b5d28dd136
--- /dev/null
+++ b/web/app/components/workflow/nodes/code/code-parser.spec.ts
@@ -0,0 +1,326 @@
+import { VarType } from '../../types'
+import { extractFunctionParams, extractReturnType } from './code-parser'
+import { CodeLanguage } from './types'
+
+const SAMPLE_CODES = {
+  python3: {
+    noParams: 'def main():',
+    singleParam: 'def main(param1):',
+    multipleParams: `def main(param1, param2, param3):
+    return {"result": param1}`,
+    withTypes: `def main(param1: str, param2: int, param3: List[str]):
+    result = process_data(param1, param2)
+    return {"output": result}`,
+    withDefaults: `def main(param1: str = "default", param2: int = 0):
+    return {"data": param1}`,
+  },
+  javascript: {
+    noParams: 'function main() {',
+    singleParam: 'function main(param1) {',
+    multipleParams: `function main(param1, param2, param3) {
+    return { result: param1 }
+  }`,
+    withComments: `// Main function
+  function main(param1, param2) {
+    // Process data
+    return { output: process(param1, param2) }
+  }`,
+    withSpaces: 'function main( param1 , param2 ) {',
+  },
+}
+
+describe('extractFunctionParams', () => {
+  describe('Python3', () => {
+    test('handles no parameters', () => {
+      const result = extractFunctionParams(SAMPLE_CODES.python3.noParams, CodeLanguage.python3)
+      expect(result).toEqual([])
+    })
+
+    test('extracts single parameter', () => {
+      const result = extractFunctionParams(SAMPLE_CODES.python3.singleParam, CodeLanguage.python3)
+      expect(result).toEqual(['param1'])
+    })
+
+    test('extracts multiple parameters', () => {
+      const result = extractFunctionParams(SAMPLE_CODES.python3.multipleParams, CodeLanguage.python3)
+      expect(result).toEqual(['param1', 'param2', 'param3'])
+    })
+
+    test('handles type hints', () => {
+      const result = extractFunctionParams(SAMPLE_CODES.python3.withTypes, CodeLanguage.python3)
+      expect(result).toEqual(['param1', 'param2', 'param3'])
+    })
+
+    test('handles default values', () => {
+      const result = extractFunctionParams(SAMPLE_CODES.python3.withDefaults, CodeLanguage.python3)
+      expect(result).toEqual(['param1', 'param2'])
+    })
+  })
+
+  // JavaScript test cases
+  describe('JavaScript', () => {
+    test('handles no parameters', () => {
+      const result = extractFunctionParams(SAMPLE_CODES.javascript.noParams, CodeLanguage.javascript)
+      expect(result).toEqual([])
+    })
+
+    test('extracts single parameter', () => {
+      const result = extractFunctionParams(SAMPLE_CODES.javascript.singleParam, CodeLanguage.javascript)
+      expect(result).toEqual(['param1'])
+    })
+
+    test('extracts multiple parameters', () => {
+      const result = extractFunctionParams(SAMPLE_CODES.javascript.multipleParams, CodeLanguage.javascript)
+      expect(result).toEqual(['param1', 'param2', 'param3'])
+    })
+
+    test('handles comments in code', () => {
+      const result = extractFunctionParams(SAMPLE_CODES.javascript.withComments, CodeLanguage.javascript)
+      expect(result).toEqual(['param1', 'param2'])
+    })
+
+    test('handles whitespace', () => {
+      const result = extractFunctionParams(SAMPLE_CODES.javascript.withSpaces, CodeLanguage.javascript)
+      expect(result).toEqual(['param1', 'param2'])
+    })
+  })
+})
+
+const RETURN_TYPE_SAMPLES = {
+  python3: {
+    singleReturn: `
+def main(param1):
+    return {"result": "value"}`,
+
+    multipleReturns: `
+def main(param1, param2):
+    return {"result": "value", "status": "success"}`,
+
+    noReturn: `
+def main():
+    print("Hello")`,
+
+    complexReturn: `
+def main():
+    data = process()
+    return {"result": data, "count": 42, "messages": ["hello"]}`,
+    nestedObject: `
+  def main(name, age, city):
+      return {
+          'personal_info': {
+              'name': name,
+              'age': age,
+              'city': city
+          },
+          'timestamp': int(time.time()),
+          'status': 'active'
+      }`,
+  },
+
+  javascript: {
+    singleReturn: `
+function main(param1) {
+  return { result: "value" }
+}`,
+
+    multipleReturns: `
+function main(param1) {
+  return { result: "value", status: "success" }
+}`,
+
+    withParentheses: `
+function main() {
+  return ({ result: "value", status: "success" })
+}`,
+
+    noReturn: `
+function main() {
+  console.log("Hello")
+}`,
+
+    withQuotes: `
+function main() {
+  return { "result": 'value', 'status': "success" }
+}`,
+    nestedObject: `
+function main(name, age, city) {
+  return {
+    personal_info: {
+      name: name,
+      age: age,
+      city: city
+    },
+    timestamp: Date.now(),
+    status: 'active'
+  }
+}`,
+    withJSDoc: `
+/**
+ * Creates a user profile with personal information and metadata
+ * @param {string} name - The user's name
+ * @param {number} age - The user's age
+ * @param {string} city - The user's city of residence
+ * @returns {Object} An object containing the user profile
+ */
+function main(name, age, city) {
+  return {
+    result: {
+      personal_info: {
+        name: name,
+        age: age,
+        city: city
+      },
+      timestamp: Date.now(),
+      status: 'active'
+    }
+  };
+}`,
+
+  },
+}
+
+describe('extractReturnType', () => {
+  // Python3 tests
+  describe('Python3', () => {
+    test('extracts single return value', () => {
+      const result = extractReturnType(RETURN_TYPE_SAMPLES.python3.singleReturn, CodeLanguage.python3)
+      expect(result).toEqual({
+        result: {
+          type: VarType.string,
+          children: null,
+        },
+      })
+    })
+
+    test('extracts multiple return values', () => {
+      const result = extractReturnType(RETURN_TYPE_SAMPLES.python3.multipleReturns, CodeLanguage.python3)
+      expect(result).toEqual({
+        result: {
+          type: VarType.string,
+          children: null,
+        },
+        status: {
+          type: VarType.string,
+          children: null,
+        },
+      })
+    })
+
+    test('returns empty object when no return statement', () => {
+      const result = extractReturnType(RETURN_TYPE_SAMPLES.python3.noReturn, CodeLanguage.python3)
+      expect(result).toEqual({})
+    })
+
+    test('handles complex return statement', () => {
+      const result = extractReturnType(RETURN_TYPE_SAMPLES.python3.complexReturn, CodeLanguage.python3)
+      expect(result).toEqual({
+        result: {
+          type: VarType.string,
+          children: null,
+        },
+        count: {
+          type: VarType.string,
+          children: null,
+        },
+        messages: {
+          type: VarType.string,
+          children: null,
+        },
+      })
+    })
+    test('handles nested object structure', () => {
+      const result = extractReturnType(RETURN_TYPE_SAMPLES.python3.nestedObject, CodeLanguage.python3)
+      expect(result).toEqual({
+        personal_info: {
+          type: VarType.string,
+          children: null,
+        },
+        timestamp: {
+          type: VarType.string,
+          children: null,
+        },
+        status: {
+          type: VarType.string,
+          children: null,
+        },
+      })
+    })
+  })
+
+  // JavaScript tests
+  describe('JavaScript', () => {
+    test('extracts single return value', () => {
+      const result = extractReturnType(RETURN_TYPE_SAMPLES.javascript.singleReturn, CodeLanguage.javascript)
+      expect(result).toEqual({
+        result: {
+          type: VarType.string,
+          children: null,
+        },
+      })
+    })
+
+    test('extracts multiple return values', () => {
+      const result = extractReturnType(RETURN_TYPE_SAMPLES.javascript.multipleReturns, CodeLanguage.javascript)
+      expect(result).toEqual({
+        result: {
+          type: VarType.string,
+          children: null,
+        },
+        status: {
+          type: VarType.string,
+          children: null,
+        },
+      })
+    })
+
+    test('handles return with parentheses', () => {
+      const result = extractReturnType(RETURN_TYPE_SAMPLES.javascript.withParentheses, CodeLanguage.javascript)
+      expect(result).toEqual({
+        result: {
+          type: VarType.string,
+          children: null,
+        },
+        status: {
+          type: VarType.string,
+          children: null,
+        },
+      })
+    })
+
+    test('returns empty object when no return statement', () => {
+      const result = extractReturnType(RETURN_TYPE_SAMPLES.javascript.noReturn, CodeLanguage.javascript)
+      expect(result).toEqual({})
+    })
+
+    test('handles quoted keys', () => {
+      const result = extractReturnType(RETURN_TYPE_SAMPLES.javascript.withQuotes, CodeLanguage.javascript)
+      expect(result).toEqual({
+        result: {
+          type: VarType.string,
+          children: null,
+        },
+        status: {
+          type: VarType.string,
+          children: null,
+        },
+      })
+    })
+    test('handles nested object structure', () => {
+      const result = extractReturnType(RETURN_TYPE_SAMPLES.javascript.nestedObject, CodeLanguage.javascript)
+      expect(result).toEqual({
+        personal_info: {
+          type: VarType.string,
+          children: null,
+        },
+        timestamp: {
+          type: VarType.string,
+          children: null,
+        },
+        status: {
+          type: VarType.string,
+          children: null,
+        },
+      })
+    })
+  })
+})
diff --git a/web/app/components/workflow/nodes/code/code-parser.ts b/web/app/components/workflow/nodes/code/code-parser.ts
new file mode 100644
index 0000000000..e1b0928f14
--- /dev/null
+++ b/web/app/components/workflow/nodes/code/code-parser.ts
@@ -0,0 +1,86 @@
+import { VarType } from '../../types'
+import type { OutputVar } from './types'
+import { CodeLanguage } from './types'
+
+export const extractFunctionParams = (code: string, language: CodeLanguage) => {
+  if (language === CodeLanguage.json)
+    return []
+
+  const patterns: Record<Exclude<CodeLanguage, CodeLanguage.json>, RegExp> = {
+    [CodeLanguage.python3]: /def\s+main\s*\((.*?)\)/,
+    [CodeLanguage.javascript]: /function\s+main\s*\((.*?)\)/,
+  }
+  const match = code.match(patterns[language])
+  const params: string[] = []
+
+  if (match?.[1]) {
+    // Split the parameter list, dropping type hints and default values.
+    params.push(...match[1].split(',')
+      .map(p => p.trim())
+      .filter(Boolean)
+      .map(p => p.split(':')[0].trim()),
+    )
+  }
+
+  return params
+}
+
+export const extractReturnType = (code: string, language: CodeLanguage): OutputVar => {
+  // Strip JSDoc-style block comments so keys inside them are not picked up.
+  const codeWithoutComments = code.replace(/\/\*\*[\s\S]*?\*\//g, '')
+
+  const returnIndex = codeWithoutComments.indexOf('return')
+  if (returnIndex === -1)
+    return {}
+
+  // Work on the substring that starts at the return statement.
+  const codeAfterReturn = codeWithoutComments.slice(returnIndex)
+
+  let bracketCount = 0
+  let startIndex = codeAfterReturn.indexOf('{')
+
+  // JavaScript allows `return ({ ... })`, so look for the brace inside parentheses.
+  if (language === CodeLanguage.javascript && startIndex === -1) {
+    const parenStart = codeAfterReturn.indexOf('(')
+    if (parenStart !== -1)
+      startIndex = codeAfterReturn.indexOf('{', parenStart)
+  }
+
+  if (startIndex === -1)
+    return {}
+
+  let endIndex = -1
+
+  // Track brace depth to find the matching closing brace of the returned object.
+  for (let i = startIndex; i < codeAfterReturn.length; i++) {
+    if (codeAfterReturn[i] === '{')
+      bracketCount++
+    if (codeAfterReturn[i] === '}') {
+      bracketCount--
+      if (bracketCount === 0) {
+        endIndex = i + 1
+        break
+      }
+    }
+  }
+
+  if (endIndex === -1)
+    return {}
+
+  const returnContent = codeAfterReturn.slice(startIndex + 1, endIndex - 1)
+
+  const result: OutputVar = {}
+
+  // Keep only top-level keys: the lookahead rejects a key when the text after
+  // its colon reaches a '}' before any '{', i.e. the key sits inside the
+  // innermost nested object.
+  const keyRegex = /['"]?(\w+)['"]?\s*:(?![^{]*})/g
+  const matches = returnContent.matchAll(keyRegex)
+
+  for (const match of matches) {
+    const key = match[1]
+    result[key] = {
+      type: VarType.string,
+      children: null,
+    }
+  }
+
+  return result
+}
diff --git a/web/app/components/workflow/nodes/code/panel.tsx b/web/app/components/workflow/nodes/code/panel.tsx
index d3e5e58634..08fc565836 100644
--- a/web/app/components/workflow/nodes/code/panel.tsx
+++ b/web/app/components/workflow/nodes/code/panel.tsx
@@ -5,6 +5,7 @@ import RemoveEffectVarConfirm from '../_base/components/remove-effect-var-confir
 import useConfig from './use-config'
 import type { CodeNodeType } from './types'
 import { CodeLanguage } from './types'
+import { extractFunctionParams, extractReturnType } from './code-parser'
 import VarList from '@/app/components/workflow/nodes/_base/components/variable/var-list'
 import OutputVarList from '@/app/components/workflow/nodes/_base/components/variable/output-var-list'
 import AddButton from '@/app/components/base/button/add-button'
@@ -12,10 +13,9 @@ import Field from '@/app/components/workflow/nodes/_base/components/field'
 import Split from '@/app/components/workflow/nodes/_base/components/split'
 import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor'
 import TypeSelector from '@/app/components/workflow/nodes/_base/components/selector'
-import type { NodePanelProps } from '@/app/components/workflow/types'
+import { type NodePanelProps } from '@/app/components/workflow/types'
 import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form'
 import ResultPanel from '@/app/components/workflow/run/result-panel'
-
 const i18nPrefix = 'workflow.nodes.code'
 
 const codeLanguages = [
@@ -38,6 +38,7 @@ const Panel: FC<NodePanelProps<CodeNodeType>> = ({
   readOnly,
   inputs,
   outputKeyOrders,
+  handleCodeAndVarsChange,
   handleVarListChange,
   handleAddVariable,
   handleRemoveVariable,
@@ -61,6 +62,18 @@
     setInputVarValues,
   } = useConfig(id, data)
 
+  const handleGeneratedCode = (value: string) => {
+    const params = extractFunctionParams(value, inputs.code_language)
+    const codeNewInput = params.map((p) => {
+      return {
+        variable: p,
+        value_selector: [],
+      }
+    })
+    const returnTypes = extractReturnType(value, inputs.code_language)
+    handleCodeAndVarsChange(value, codeNewInput, returnTypes)
+  }
+
   return (
@@ -92,6 +105,7 @@ const Panel: FC<NodePanelProps<CodeNodeType>> = ({
             language={inputs.code_language}
             value={inputs.code}
            onChange={handleCodeChange}
+            onGenerated={handleGeneratedCode}
             showCodeGenerator={true}
           />
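Taken together, the panel wiring above means a freshly generated snippet flows from the code generator into `handleGeneratedCode`, which derives the node's input variables and output schema via the two new parsers. A small sketch of the expected round trip (the `generated` snippet is made up; the expected shapes mirror the spec file above):

```ts
import { extractFunctionParams, extractReturnType } from './code-parser'
import { CodeLanguage } from './types'

const generated = `
def main(query: str, top_k: int = 5):
    return {"answer": query, "count": top_k}
`

// ['query', 'top_k']: type hints and default values are stripped.
const params = extractFunctionParams(generated, CodeLanguage.python3)

// { answer: { type: VarType.string, children: null },
//   count:  { type: VarType.string, children: null } }
// Every extracted key is currently typed as string, as the spec expects.
const outputs = extractReturnType(generated, CodeLanguage.python3)
```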
diff --git a/web/app/components/workflow/nodes/code/use-config.ts b/web/app/components/workflow/nodes/code/use-config.ts
index 07fe85aa0f..c53c07a28e 100644
--- a/web/app/components/workflow/nodes/code/use-config.ts
+++ b/web/app/components/workflow/nodes/code/use-config.ts
@@ -3,7 +3,7 @@ import produce from 'immer'
 import useVarList from '../_base/hooks/use-var-list'
 import useOutputVarList from '../_base/hooks/use-output-var-list'
 import { BlockEnum, VarType } from '../../types'
-import type { Var } from '../../types'
+import type { Var, Variable } from '../../types'
 import { useStore } from '../../store'
 import type { CodeNodeType, OutputVar } from './types'
 import { CodeLanguage } from './types'
@@ -136,7 +136,15 @@ const useConfig = (id: string, payload: CodeNodeType) => {
   const setInputVarValues = useCallback((newPayload: Record<string, any>) => {
     setRunInputData(newPayload)
   }, [setRunInputData])
-
+  const handleCodeAndVarsChange = useCallback((code: string, inputVariables: Variable[], outputVariables: OutputVar) => {
+    const newInputs = produce(inputs, (draft) => {
+      draft.code = code
+      draft.variables = inputVariables
+      draft.outputs = outputVariables
+    })
+    setInputs(newInputs)
+    syncOutputKeyOrders(outputVariables)
+  }, [inputs, setInputs, syncOutputKeyOrders])
   return {
     readOnly,
     inputs,
@@ -163,6 +171,7 @@
     inputVarValues,
     setInputVarValues,
     runResult,
+    handleCodeAndVarsChange,
   }
 }
diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts
index d280a2d63e..288a718aa2 100644
--- a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts
+++ b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts
@@ -240,7 +240,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
     if (
       (allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel))
       || mixtureInternalAndExternal
-      || (allExternal && newDatasets.length > 1)
+      || allExternal
     )
       setRerankModelOpen(true)
   }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel])
diff --git a/web/service/common.ts b/web/service/common.ts
index 70586b6ff6..9acbd75940 100644
--- a/web/service/common.ts
+++ b/web/service/common.ts
@@ -320,9 +320,10 @@
 export const verifyForgotPasswordToken: Fetcher = ({ url, body }) =>
   post(url, { body })
 
-export const fetchRemoteFileInfo = (url: string) => {
-  return get<{ file_type: string; file_length: number }>(`/remote-files/${url}`)
+export const uploadRemoteFileInfo = (url: string) => {
+  return post<{ id: string; name: string; size: number; mime_type: string; url: string }>('/remote-files/upload', { body: { url } })
 }
+
 export const sendEMailLoginCode = (email: string, language = 'en-US') =>
   post('/email-code-login', { body: { email, language } })
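For callers migrating from `fetchRemoteFileInfo`, here is a minimal sketch of the renamed helper (assuming, as the surrounding service code suggests, that the typed `post` wrapper resolves to the parsed response body; the example URL is a placeholder):

```ts
import { uploadRemoteFileInfo } from '@/service/common'

// The backend now fetches the remote URL itself and registers it as an
// upload, so the response carries a server-side file id plus metadata
// instead of the old raw file_type/file_length pair.
async function importRemoteFile() {
  const file = await uploadRemoteFileInfo('https://example.com/image.png')
  console.log(file.id, file.name, file.size, file.mime_type, file.url)
  return file
}
```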