diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index e9c2b7b086..c87d5a4dd4 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -78,7 +78,7 @@ jobs: - name: Run Workflow run: poetry run -C api bash dev/pytest/pytest_workflow.sh - - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch) + - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase) uses: hoverkraft-tech/compose-action@v2.0.0 with: compose-file: | @@ -86,6 +86,7 @@ jobs: services: | weaviate qdrant + couchbase-server etcd minio milvus-standalone diff --git a/.github/workflows/expose_service_ports.sh b/.github/workflows/expose_service_ports.sh index ae3e0ee69d..bc65c19a91 100755 --- a/.github/workflows/expose_service_ports.sh +++ b/.github/workflows/expose_service_ports.sh @@ -7,5 +7,7 @@ yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/dock yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml +yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml +yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml -echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch" \ No newline at end of file +echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase" diff --git a/.gitignore b/.gitignore index 27cf8a4ba3..29a80534f7 100644 --- a/.gitignore +++ b/.gitignore @@ -173,6 +173,7 @@ docker/volumes/myscale/log/* docker/volumes/unstructured/* docker/volumes/pgvector/data/* docker/volumes/pgvecto_rs/data/* +docker/volumes/couchbase/* docker/nginx/conf.d/default.conf docker/nginx/ssl/* @@ -189,4 +190,4 @@ pyrightconfig.json api/.vscode .idea/ -.vscode \ No newline at end of file +.vscode diff --git a/README.md b/README.md index f6d14bb840..cd783501e2 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ ![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +

+ 📌 Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast +

+

Dify Cloud · Self-hosting · diff --git a/README_CN.md b/README_CN.md index 689f98ccf4..070951699a 100644 --- a/README_CN.md +++ b/README_CN.md @@ -154,7 +154,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI 我们提供[ Dify 云服务](https://dify.ai),任何人都可以零设置尝试。它提供了自部署版本的所有功能,并在沙盒计划中包含 200 次免费的 GPT-4 调用。 - **自托管 Dify 社区版
** -使用这个[入门指南](#quick-start)快速在您的环境中运行 Dify。 +使用这个[入门指南](#快速启动)快速在您的环境中运行 Dify。 使用我们的[文档](https://docs.dify.ai)进行进一步的参考和更深入的说明。 - **面向企业/组织的 Dify
** diff --git a/api/.env.example b/api/.env.example index 960d8b1879..2ce425338e 100644 --- a/api/.env.example +++ b/api/.env.example @@ -31,8 +31,17 @@ REDIS_HOST=localhost REDIS_PORT=6379 REDIS_USERNAME= REDIS_PASSWORD=difyai123456 +REDIS_USE_SSL=false REDIS_DB=0 +# redis Sentinel configuration. +REDIS_USE_SENTINEL=false +REDIS_SENTINELS= +REDIS_SENTINEL_SERVICE_NAME= +REDIS_SENTINEL_USERNAME= +REDIS_SENTINEL_PASSWORD= +REDIS_SENTINEL_SOCKET_TIMEOUT=0.1 + # PostgreSQL database configuration DB_USERNAME=postgres DB_PASSWORD=difyai123456 @@ -111,7 +120,7 @@ SUPABASE_URL=your-server-url WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* -# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb, upstash +# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash VECTOR_STORE=weaviate # Weaviate configuration @@ -127,6 +136,13 @@ QDRANT_CLIENT_TIMEOUT=20 QDRANT_GRPC_ENABLED=false QDRANT_GRPC_PORT=6334 +#Couchbase configuration +COUCHBASE_CONNECTION_STRING=127.0.0.1 +COUCHBASE_USER=Administrator +COUCHBASE_PASSWORD=password +COUCHBASE_BUCKET_NAME=Embeddings +COUCHBASE_SCOPE_NAME=_default + # Milvus configuration MILVUS_URI=http://127.0.0.1:19530 MILVUS_TOKEN= diff --git a/api/Dockerfile b/api/Dockerfile index c6381859b3..f078181264 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -55,7 +55,9 @@ RUN apt-get update \ && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \ && apt-get update \ # For Security - && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.1-1 \ + && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \ + # install a chinese font to support the use of tools like matplotlib + && apt-get install -y fonts-noto-cjk \ && apt-get autoremove -y \ && rm -rf /var/lib/apt/lists/* diff --git a/api/commands.py b/api/commands.py index 720a4447da..da09f1b610 100644 --- a/api/commands.py +++ b/api/commands.py @@ -278,6 +278,7 @@ def migrate_knowledge_vector_database(): VectorType.BAIDU, VectorType.VIKINGDB, VectorType.UPSTASH, + VectorType.COUCHBASE, } page = 1 while True: diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 4d6c9aedc1..0fa926038d 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -571,6 +571,11 @@ class DataSetConfig(BaseSettings): default=False, ) + TIDB_SERVERLESS_NUMBER: PositiveInt = Field( + description="number of tidb serverless cluster", + default=500, + ) + class WorkspaceConfig(BaseSettings): """ diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 705e30b7ff..e8f6ba91b6 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -17,6 +17,7 @@ from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCO from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig from configs.middleware.vdb.chroma_config import ChromaConfig +from configs.middleware.vdb.couchbase_config import CouchbaseConfig from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig from configs.middleware.vdb.milvus_config import MilvusConfig from configs.middleware.vdb.myscale_config import MyScaleConfig @@ -27,6 +28,7 @@ from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig from configs.middleware.vdb.qdrant_config import QdrantConfig from configs.middleware.vdb.relyt_config import RelytConfig from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig +from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig from configs.middleware.vdb.upstash_config import UpstashConfig from configs.middleware.vdb.vikingdb_config import VikingDBConfig @@ -54,6 +56,11 @@ class VectorStoreConfig(BaseSettings): default=None, ) + VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field( + description="Enable whitelist for vector store.", + default=False, + ) + class KeywordStoreConfig(BaseSettings): KEYWORD_STORE: str = Field( @@ -245,8 +252,10 @@ class MiddlewareConfig( TiDBVectorConfig, WeaviateConfig, ElasticsearchConfig, + CouchbaseConfig, InternalTestConfig, VikingDBConfig, UpstashConfig, + TidbOnQdrantConfig, ): pass diff --git a/api/configs/middleware/vdb/couchbase_config.py b/api/configs/middleware/vdb/couchbase_config.py new file mode 100644 index 0000000000..391089ec6e --- /dev/null +++ b/api/configs/middleware/vdb/couchbase_config.py @@ -0,0 +1,34 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class CouchbaseConfig(BaseModel): + """ + Couchbase configs + """ + + COUCHBASE_CONNECTION_STRING: Optional[str] = Field( + description="COUCHBASE connection string", + default=None, + ) + + COUCHBASE_USER: Optional[str] = Field( + description="COUCHBASE user", + default=None, + ) + + COUCHBASE_PASSWORD: Optional[str] = Field( + description="COUCHBASE password", + default=None, + ) + + COUCHBASE_BUCKET_NAME: Optional[str] = Field( + description="COUCHBASE bucket name", + default=None, + ) + + COUCHBASE_SCOPE_NAME: Optional[str] = Field( + description="COUCHBASE scope name", + default=None, + ) diff --git a/api/configs/middleware/vdb/tidb_on_qdrant_config.py b/api/configs/middleware/vdb/tidb_on_qdrant_config.py new file mode 100644 index 0000000000..98268798ef --- /dev/null +++ b/api/configs/middleware/vdb/tidb_on_qdrant_config.py @@ -0,0 +1,65 @@ +from typing import Optional + +from pydantic import Field, NonNegativeInt, PositiveInt +from pydantic_settings import BaseSettings + + +class TidbOnQdrantConfig(BaseSettings): + """ + Tidb on Qdrant configs + """ + + TIDB_ON_QDRANT_URL: Optional[str] = Field( + description="Tidb on Qdrant url", + default=None, + ) + + TIDB_ON_QDRANT_API_KEY: Optional[str] = Field( + description="Tidb on Qdrant api key", + default=None, + ) + + TIDB_ON_QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field( + description="Tidb on Qdrant client timeout in seconds", + default=20, + ) + + TIDB_ON_QDRANT_GRPC_ENABLED: bool = Field( + description="whether enable grpc support for Tidb on Qdrant connection", + default=False, + ) + + TIDB_ON_QDRANT_GRPC_PORT: PositiveInt = Field( + description="Tidb on Qdrant grpc port", + default=6334, + ) + + TIDB_PUBLIC_KEY: Optional[str] = Field( + description="Tidb account public key", + default=None, + ) + + TIDB_PRIVATE_KEY: Optional[str] = Field( + description="Tidb account private key", + default=None, + ) + + TIDB_API_URL: Optional[str] = Field( + description="Tidb API url", + default=None, + ) + + TIDB_IAM_API_URL: Optional[str] = Field( + description="Tidb IAM API url", + default=None, + ) + + TIDB_REGION: Optional[str] = Field( + description="Tidb serverless region", + default="regions/aws-us-east-1", + ) + + TIDB_PROJECT_ID: Optional[str] = Field( + description="Tidb project id", + default=None, + ) diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index 389a64f53e..3dc87e3058 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="0.10.1", + default="0.10.2", ) COMMIT_SHA: str = Field( diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 2fba3e0af0..fe06201982 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -105,6 +105,8 @@ class ChatMessageListApi(Resource): if rest_count > 0: has_more = True + history_messages = list(reversed(history_messages)) + return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 6e6792936e..854821746a 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -102,6 +102,13 @@ class DatasetListApi(Resource): help="type is required. Name must be between 1 to 40 characters.", type=_validate_name, ) + parser.add_argument( + "description", + type=str, + nullable=True, + required=False, + default="", + ) parser.add_argument( "indexing_technique", type=str, @@ -140,6 +147,7 @@ class DatasetListApi(Resource): dataset = DatasetService.create_empty_dataset( tenant_id=current_user.current_tenant_id, name=args["name"], + description=args["description"], indexing_technique=args["indexing_technique"], account=current_user, permission=DatasetPermissionEnum.ONLY_ME, @@ -631,6 +639,8 @@ class DatasetRetrievalSettingApi(Resource): | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR + | VectorType.TIDB_ON_QDRANT + | VectorType.COUCHBASE ): return { "retrieval_method": [ @@ -669,6 +679,7 @@ class DatasetRetrievalSettingMockApi(Resource): | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH + | VectorType.COUCHBASE | VectorType.PGVECTOR ): return { diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index aab7dd7888..7c7580e3c6 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -21,7 +21,12 @@ class AppParameterApi(InstalledAppResource): "options": fields.List(fields.String), } - system_parameters_fields = {"image_file_size_limit": fields.String} + system_parameters_fields = { + "image_file_size_limit": fields.Integer, + "video_file_size_limit": fields.Integer, + "audio_file_size_limit": fields.Integer, + "file_size_limit": fields.Integer, + } parameters_fields = { "opening_statement": fields.String, @@ -82,7 +87,12 @@ class AppParameterApi(InstalledAppResource): } }, ), - "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, + "system_parameters": { + "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, + "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, + "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, + }, } diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 914b60f263..fee840b30d 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -21,7 +21,7 @@ class EnterpriseWorkspace(Resource): if account is None: return {"message": "owner account not found."}, 404 - tenant = TenantService.create_tenant(args["name"]) + tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True) TenantService.create_tenant_member(tenant, account, role="owner") tenant_was_created.send(tenant) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index f7c091217b..9a4cdc26cd 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -21,7 +21,12 @@ class AppParameterApi(Resource): "options": fields.List(fields.String), } - system_parameters_fields = {"image_file_size_limit": fields.String} + system_parameters_fields = { + "image_file_size_limit": fields.Integer, + "video_file_size_limit": fields.Integer, + "audio_file_size_limit": fields.Integer, + "file_size_limit": fields.Integer, + } parameters_fields = { "opening_statement": fields.String, @@ -81,7 +86,12 @@ class AppParameterApi(Resource): } }, ), - "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, + "system_parameters": { + "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, + "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, + "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, + }, } diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index f076cff6c8..799fccc228 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -66,6 +66,13 @@ class DatasetListApi(DatasetApiResource): help="type is required. Name must be between 1 to 40 characters.", type=_validate_name, ) + parser.add_argument( + "description", + type=str, + nullable=True, + required=False, + default="", + ) parser.add_argument( "indexing_technique", type=str, @@ -108,6 +115,7 @@ class DatasetListApi(DatasetApiResource): dataset = DatasetService.create_empty_dataset( tenant_id=tenant_id, name=args["name"], + description=args["description"], indexing_technique=args["indexing_technique"], account=current_user, permission=args["permission"], diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 20b4e4674c..974d2cff94 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -21,7 +21,12 @@ class AppParameterApi(WebApiResource): "options": fields.List(fields.String), } - system_parameters_fields = {"image_file_size_limit": fields.String} + system_parameters_fields = { + "image_file_size_limit": fields.Integer, + "video_file_size_limit": fields.Integer, + "audio_file_size_limit": fields.Integer, + "file_size_limit": fields.Integer, + } parameters_fields = { "opening_statement": fields.String, @@ -80,7 +85,12 @@ class AppParameterApi(WebApiResource): } }, ), - "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, + "system_parameters": { + "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, + "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, + "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, + }, } diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 514dcfbd68..507455c176 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -165,6 +165,12 @@ class BaseAgentRunner(AppRunner): continue parameter_type = parameter.type.as_normal_type() + if parameter.type in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + }: + continue enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] @@ -250,6 +256,12 @@ class BaseAgentRunner(AppRunner): continue parameter_type = parameter.type.as_normal_type() + if parameter.type in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + }: + continue enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index 0c6ce8ce75..b69d7a74c0 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -76,8 +76,16 @@ def to_prompt_message_content(f: File, /): def download(f: File, /): - upload_file = file_repository.get_upload_file(session=db.session(), file=f) - return _download_file_content(upload_file.key) + if f.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file = file_repository.get_tool_file(session=db.session(), file=f) + return _download_file_content(tool_file.file_key) + elif f.transfer_method == FileTransferMethod.LOCAL_FILE: + upload_file = file_repository.get_upload_file(session=db.session(), file=f) + return _download_file_content(upload_file.key) + # remote file + response = ssrf_proxy.get(f.remote_url, follow_redirects=True) + response.raise_for_status() + return response.content def _download_file_content(path: str, /): diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index 52b590f66a..88531d8ae0 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -105,6 +105,7 @@ class LLMResult(BaseModel): Model class for llm result. """ + id: Optional[str] = None model: str prompt_messages: list[PromptMessage] message: AssistantPromptMessage diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml index 093f57c51e..1ef5e83abc 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml @@ -53,6 +53,9 @@ model_credential_schema: type: select required: true options: + - label: + en_US: 2024-10-01-preview + value: 2024-10-01-preview - label: en_US: 2024-09-01-preview value: 2024-09-01-preview diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 4f46fb81f8..1cd4823e13 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -45,9 +45,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): stream: bool = True, user: Optional[str] = None, ) -> Union[LLMResult, Generator]: - base_model_name = credentials.get("base_model_name") - if not base_model_name: - raise ValueError("Base Model Name is required") + base_model_name = self._get_base_model_name(credentials) ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: @@ -81,9 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None, ) -> int: - base_model_name = credentials.get("base_model_name") - if not base_model_name: - raise ValueError("Base Model Name is required") + base_model_name = self._get_base_model_name(credentials) model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if not model_entity: raise ValueError(f"Base Model Name {base_model_name} is invalid") @@ -108,9 +104,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): if "base_model_name" not in credentials: raise CredentialsValidateFailedError("Base Model Name is required") - base_model_name = credentials.get("base_model_name") - if not base_model_name: - raise CredentialsValidateFailedError("Base Model Name is required") + base_model_name = self._get_base_model_name(credentials) ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if not ai_model_entity: @@ -149,9 +143,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - base_model_name = credentials.get("base_model_name") - if not base_model_name: - raise ValueError("Base Model Name is required") + base_model_name = self._get_base_model_name(credentials) ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) return ai_model_entity.entity if ai_model_entity else None @@ -308,11 +300,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): if tools: extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] - # extra_model_kwargs['functions'] = [{ - # "name": tool.name, - # "description": tool.description, - # "parameters": tool.parameters - # } for tool in tools] if stop: extra_model_kwargs["stop"] = stop @@ -769,3 +756,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ai_model_entity_copy.entity.label.en_US = model ai_model_entity_copy.entity.label.zh_Hans = model return ai_model_entity_copy + + def _get_base_model_name(self, credentials: dict) -> str: + base_model_name = credentials.get("base_model_name") + if not base_model_name: + raise ValueError("Base Model Name is required") + return base_model_name diff --git a/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo-full.svg b/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo-full.svg new file mode 100644 index 0000000000..f9738b585b --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo-full.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo.svg b/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo.svg new file mode 100644 index 0000000000..1f51187f19 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/model_runtime/model_providers/gitee_ai/_common.py b/api/core/model_runtime/model_providers/gitee_ai/_common.py new file mode 100644 index 0000000000..0750f3b75d --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/_common.py @@ -0,0 +1,47 @@ +from dashscope.common.error import ( + AuthenticationError, + InvalidParameter, + RequestFailure, + ServiceUnavailableError, + UnsupportedHTTPMethod, + UnsupportedModel, +) + +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +class _CommonGiteeAI: + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + RequestFailure, + ], + InvokeServerUnavailableError: [ + ServiceUnavailableError, + ], + InvokeRateLimitError: [], + InvokeAuthorizationError: [ + AuthenticationError, + ], + InvokeBadRequestError: [ + InvalidParameter, + UnsupportedModel, + UnsupportedHTTPMethod, + ], + } diff --git a/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py new file mode 100644 index 0000000000..ca67594ce4 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py @@ -0,0 +1,25 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class GiteeAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + model_instance.validate_credentials(model="Qwen2-7B-Instruct", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.yaml b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.yaml new file mode 100644 index 0000000000..7f7d0f2e53 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.yaml @@ -0,0 +1,35 @@ +provider: gitee_ai +label: + en_US: Gitee AI + zh_Hans: Gitee AI +description: + en_US: 快速体验大模型,领先探索 AI 开源世界 + zh_Hans: 快速体验大模型,领先探索 AI 开源世界 +icon_small: + en_US: Gitee-AI-Logo.svg +icon_large: + en_US: Gitee-AI-Logo-full.svg +help: + title: + en_US: Get your token from Gitee AI + zh_Hans: 从 Gitee AI 获取 token + url: + en_US: https://ai.gitee.com/dashboard/settings/tokens +supported_model_types: + - llm + - text-embedding + - rerank + - speech2text + - tts +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-72B-Instruct.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-72B-Instruct.yaml new file mode 100644 index 0000000000..0348438a75 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-72B-Instruct.yaml @@ -0,0 +1,105 @@ +model: Qwen2-72B-Instruct +label: + zh_Hans: Qwen2-72B-Instruct + en_US: Qwen2-72B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 6400 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-7B-Instruct.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-7B-Instruct.yaml new file mode 100644 index 0000000000..ba1ad788f5 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-7B-Instruct.yaml @@ -0,0 +1,105 @@ +model: Qwen2-7B-Instruct +label: + zh_Hans: Qwen2-7B-Instruct + en_US: Qwen2-7B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/Yi-1.5-34B-Chat.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/Yi-1.5-34B-Chat.yaml new file mode 100644 index 0000000000..f7260c987b --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/Yi-1.5-34B-Chat.yaml @@ -0,0 +1,105 @@ +model: Yi-1.5-34B-Chat +label: + zh_Hans: Yi-1.5-34B-Chat + en_US: Yi-1.5-34B-Chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/_position.yaml new file mode 100644 index 0000000000..21f6120742 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/_position.yaml @@ -0,0 +1,7 @@ +- Qwen2-7B-Instruct +- Qwen2-72B-Instruct +- Yi-1.5-34B-Chat +- glm-4-9b-chat +- deepseek-coder-33B-instruct-chat +- deepseek-coder-33B-instruct-completions +- codegeex4-all-9b diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/codegeex4-all-9b.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/codegeex4-all-9b.yaml new file mode 100644 index 0000000000..8632cd92ab --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/codegeex4-all-9b.yaml @@ -0,0 +1,105 @@ +model: codegeex4-all-9b +label: + zh_Hans: codegeex4-all-9b + en_US: codegeex4-all-9b +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 40960 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-chat.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-chat.yaml new file mode 100644 index 0000000000..2ac00761d5 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-chat.yaml @@ -0,0 +1,105 @@ +model: deepseek-coder-33B-instruct-chat +label: + zh_Hans: deepseek-coder-33B-instruct-chat + en_US: deepseek-coder-33B-instruct-chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 9000 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-completions.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-completions.yaml new file mode 100644 index 0000000000..7c364d89f7 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-completions.yaml @@ -0,0 +1,91 @@ +model: deepseek-coder-33B-instruct-completions +label: + zh_Hans: deepseek-coder-33B-instruct-completions + en_US: deepseek-coder-33B-instruct-completions +model_type: llm +features: + - agent-thought +model_properties: + mode: completion + context_size: 9000 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/glm-4-9b-chat.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/glm-4-9b-chat.yaml new file mode 100644 index 0000000000..2afe1cf959 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/glm-4-9b-chat.yaml @@ -0,0 +1,105 @@ +model: glm-4-9b-chat +label: + zh_Hans: glm-4-9b-chat + en_US: glm-4-9b-chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py b/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py new file mode 100644 index 0000000000..b65db6f665 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py @@ -0,0 +1,47 @@ +from collections.abc import Generator +from typing import Optional, Union + +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel): + MODEL_TO_IDENTITY: dict[str, str] = { + "Yi-1.5-34B-Chat": "Yi-34B-Chat", + "deepseek-coder-33B-instruct-completions": "deepseek-coder-33B-instruct", + "deepseek-coder-33B-instruct-chat": "deepseek-coder-33B-instruct", + } + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + self._add_custom_parameters(credentials, model, model_parameters) + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials, model, None) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials: dict, model: str, model_parameters: dict) -> None: + if model is None: + model = "bge-large-zh-v1.5" + + model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model) + credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/" + if model.endswith("completions"): + credentials["mode"] = LLMMode.COMPLETION.value + else: + credentials["mode"] = LLMMode.CHAT.value diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/__init__.py b/api/core/model_runtime/model_providers/gitee_ai/rerank/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/rerank/_position.yaml new file mode 100644 index 0000000000..83162fd338 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/_position.yaml @@ -0,0 +1 @@ +- bge-reranker-v2-m3 diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/bge-reranker-v2-m3.yaml b/api/core/model_runtime/model_providers/gitee_ai/rerank/bge-reranker-v2-m3.yaml new file mode 100644 index 0000000000..f0681641e1 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/bge-reranker-v2-m3.yaml @@ -0,0 +1,4 @@ +model: bge-reranker-v2-m3 +model_type: rerank +model_properties: + context_size: 1024 diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py new file mode 100644 index 0000000000..231345c2f4 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py @@ -0,0 +1,128 @@ +from typing import Optional + +import httpx + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class GiteeAIRerankModel(RerankModel): + """ + Model class for rerank model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + base_url = credentials.get("base_url", "https://ai.gitee.com/api/serverless") + base_url = base_url.removesuffix("/") + + try: + body = {"model": model, "query": query, "documents": docs} + if top_n is not None: + body["top_n"] = top_n + response = httpx.post( + f"{base_url}/{model}/rerank", + json=body, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"}, + ) + + response.raise_for_status() + results = response.json() + + rerank_documents = [] + for result in results["results"]: + rerank_document = RerankDocument( + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], + ) + if score_threshold is None or result["relevance_score"] >= score_threshold: + rerank_documents.append(rerank_document) + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.01, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/__init__.py b/api/core/model_runtime/model_providers/gitee_ai/speech2text/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/speech2text/_position.yaml new file mode 100644 index 0000000000..8e9b47598b --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/speech2text/_position.yaml @@ -0,0 +1,2 @@ +- whisper-base +- whisper-large diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/gitee_ai/speech2text/speech2text.py new file mode 100644 index 0000000000..5597f5b43e --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/speech2text/speech2text.py @@ -0,0 +1,53 @@ +import os +from typing import IO, Optional + +import requests + +from core.model_runtime.errors.invoke import InvokeBadRequestError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI + + +class GiteeAISpeech2TextModel(_CommonGiteeAI, Speech2TextModel): + """ + Model class for OpenAI Compatible Speech to text model. + """ + + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: + """ + Invoke speech2text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + # doc: https://ai.gitee.com/docs/openapi/serverless#tag/serverless/POST/{service}/speech-to-text + + endpoint_url = f"https://ai.gitee.com/api/serverless/{model}/speech-to-text" + files = [("file", file)] + _, file_ext = os.path.splitext(file.name) + headers = {"Content-Type": f"audio/{file_ext}", "Authorization": f"Bearer {credentials.get('api_key')}"} + response = requests.post(endpoint_url, headers=headers, files=files) + if response.status_code != 200: + raise InvokeBadRequestError(response.text) + response_data = response.json() + return response_data["text"] + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + audio_file_path = self._get_demo_file_path() + + with open(audio_file_path, "rb") as audio_file: + self._invoke(model, credentials, audio_file) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-base.yaml b/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-base.yaml new file mode 100644 index 0000000000..a50bf5fc2d --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-base.yaml @@ -0,0 +1,5 @@ +model: whisper-base +model_type: speech2text +model_properties: + file_upload_limit: 1 + supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-large.yaml b/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-large.yaml new file mode 100644 index 0000000000..1be7b1a391 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-large.yaml @@ -0,0 +1,5 @@ +model: whisper-large +model_type: speech2text +model_properties: + file_upload_limit: 1 + supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/_position.yaml new file mode 100644 index 0000000000..e8abe6440d --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/_position.yaml @@ -0,0 +1,3 @@ +- bge-large-zh-v1.5 +- bge-small-zh-v1.5 +- bge-m3 diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-large-zh-v1.5.yaml b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-large-zh-v1.5.yaml new file mode 100644 index 0000000000..9e3ca76e88 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-large-zh-v1.5.yaml @@ -0,0 +1,8 @@ +model: bge-large-zh-v1.5 +label: + zh_Hans: bge-large-zh-v1.5 + en_US: bge-large-zh-v1.5 +model_type: text-embedding +model_properties: + context_size: 200000 + max_chunks: 20 diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-m3.yaml b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-m3.yaml new file mode 100644 index 0000000000..a7a99a98a3 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-m3.yaml @@ -0,0 +1,8 @@ +model: bge-m3 +label: + zh_Hans: bge-m3 + en_US: bge-m3 +model_type: text-embedding +model_properties: + context_size: 200000 + max_chunks: 20 diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-small-zh-v1.5.yaml b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-small-zh-v1.5.yaml new file mode 100644 index 0000000000..bd760408fa --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-small-zh-v1.5.yaml @@ -0,0 +1,8 @@ +model: bge-small-zh-v1.5 +label: + zh_Hans: bge-small-zh-v1.5 + en_US: bge-small-zh-v1.5 +model_type: text-embedding +model_properties: + context_size: 200000 + max_chunks: 20 diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py new file mode 100644 index 0000000000..b833c5652c --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py @@ -0,0 +1,31 @@ +from typing import Optional + +from core.entities.embedding_type import EmbeddingInputType +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( + OAICompatEmbeddingModel, +) + + +class GiteeAIEmbeddingModel(OAICompatEmbeddingModel): + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + self._add_custom_parameters(credentials, model) + return super()._invoke(model, credentials, texts, user, input_type) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials, None) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials: dict, model: str) -> None: + if model is None: + model = "bge-m3" + + credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model}/v1/" diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/ChatTTS.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/ChatTTS.yaml new file mode 100644 index 0000000000..940391dfab --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/ChatTTS.yaml @@ -0,0 +1,11 @@ +model: ChatTTS +model_type: tts +model_properties: + default_voice: 'default' + voices: + - mode: 'default' + name: 'Default' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + word_limit: 3500 + audio_type: 'mp3' + max_workers: 5 diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/FunAudioLLM-CosyVoice-300M.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/FunAudioLLM-CosyVoice-300M.yaml new file mode 100644 index 0000000000..8fc5734801 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/FunAudioLLM-CosyVoice-300M.yaml @@ -0,0 +1,11 @@ +model: FunAudioLLM-CosyVoice-300M +model_type: tts +model_properties: + default_voice: 'default' + voices: + - mode: 'default' + name: 'Default' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + word_limit: 3500 + audio_type: 'mp3' + max_workers: 5 diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/__init__.py b/api/core/model_runtime/model_providers/gitee_ai/tts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/_position.yaml new file mode 100644 index 0000000000..13c6ec8454 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/_position.yaml @@ -0,0 +1,4 @@ +- speecht5_tts +- ChatTTS +- fish-speech-1.2-sft +- FunAudioLLM-CosyVoice-300M diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/fish-speech-1.2-sft.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/fish-speech-1.2-sft.yaml new file mode 100644 index 0000000000..93cc28bc9d --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/fish-speech-1.2-sft.yaml @@ -0,0 +1,11 @@ +model: fish-speech-1.2-sft +model_type: tts +model_properties: + default_voice: 'default' + voices: + - mode: 'default' + name: 'Default' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + word_limit: 3500 + audio_type: 'mp3' + max_workers: 5 diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/speecht5_tts.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/speecht5_tts.yaml new file mode 100644 index 0000000000..f9c843bd41 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/speecht5_tts.yaml @@ -0,0 +1,11 @@ +model: speecht5_tts +model_type: tts +model_properties: + default_voice: 'default' + voices: + - mode: 'default' + name: 'Default' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + word_limit: 3500 + audio_type: 'mp3' + max_workers: 5 diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py new file mode 100644 index 0000000000..ed2bd5b13d --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py @@ -0,0 +1,79 @@ +from typing import Optional + +import requests + +from core.model_runtime.errors.invoke import InvokeBadRequestError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.tts_model import TTSModel +from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI + + +class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel): + """ + Model class for OpenAI Speech to text model. + """ + + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: + """ + _invoke text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :param user: unique user id + :return: text translated to audio file + """ + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + validate credentials text2speech model + + :param model: model name + :param credentials: model credentials + :return: text translated to audio file + """ + try: + self._tts_invoke_streaming( + model=model, + credentials=credentials, + content_text="Hello Dify!", + voice=self._get_model_default_voice(model, credentials), + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: + """ + _tts_invoke_streaming text2speech model + :param model: model name + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :return: text translated to audio file + """ + try: + # doc: https://ai.gitee.com/docs/openapi/serverless#tag/serverless/POST/{service}/text-to-speech + endpoint_url = "https://ai.gitee.com/api/serverless/" + model + "/text-to-speech" + + headers = {"Content-Type": "application/json"} + api_key = credentials.get("api_key") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + payload = {"inputs": content_text} + response = requests.post(endpoint_url, headers=headers, json=payload) + + if response.status_code != 200: + raise InvokeBadRequestError(response.text) + + data = response.content + + for i in range(0, len(data), 1024): + yield data[i : i + 1024] + except Exception as ex: + raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index e686ad08d9..b1b07a611b 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -116,26 +116,33 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :param tools: tool messages :return: glm tools """ - return glm.Tool( - function_declarations=[ - glm.FunctionDeclaration( - name=tool.name, - parameters=glm.Schema( - type=glm.Type.OBJECT, - properties={ - key: { - "type_": value.get("type", "string").upper(), - "description": value.get("description", ""), - "enum": value.get("enum", []), - } - for key, value in tool.parameters.get("properties", {}).items() - }, - required=tool.parameters.get("required", []), - ), + function_declarations = [] + for tool in tools: + properties = {} + for key, value in tool.parameters.get("properties", {}).items(): + properties[key] = { + "type_": glm.Type.STRING, + "description": value.get("description", ""), + "enum": value.get("enum", []), + } + + if properties: + parameters = glm.Schema( + type=glm.Type.OBJECT, + properties=properties, + required=tool.parameters.get("required", []), ) - for tool in tools - ] - ) + else: + parameters = None + + function_declaration = glm.FunctionDeclaration( + name=tool.name, + parameters=parameters, + description=tool.description, + ) + function_declarations.append(function_declaration) + + return glm.Tool(function_declarations=function_declarations) def validate_credentials(self, model: str, credentials: dict) -> None: """ diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index 01a2a07325..5c955c86d3 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -44,6 +44,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): self._add_custom_parameters(credentials) self._add_function_call(model, credentials) user = user[:32] if user else None + # {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}} + if "response_format" in model_parameters: + model_parameters["response_format"] = {"type": model_parameters.get("response_format")} return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 356ac56b1e..e1342fe985 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -397,16 +397,21 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): chunk_index = 0 def create_final_llm_result_chunk( - index: int, message: AssistantPromptMessage, finish_reason: str + id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict ) -> LLMResultChunk: # calculate num tokens - prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) - completion_tokens = self._num_tokens_from_string(model, full_assistant_content) + prompt_tokens = usage and usage.get("prompt_tokens") + if prompt_tokens is None: + prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) + completion_tokens = usage and usage.get("completion_tokens") + if completion_tokens is None: + completion_tokens = self._num_tokens_from_string(model, full_assistant_content) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) return LLMResultChunk( + id=id, model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), @@ -450,7 +455,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): tool_call.function.arguments += new_tool_call.function.arguments finish_reason = None # The default value of finish_reason is None - + message_id, usage = None, None for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): chunk = chunk.strip() if chunk: @@ -462,20 +467,26 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): continue try: - chunk_json = json.loads(decoded_chunk) + chunk_json: dict = json.loads(decoded_chunk) # stream ended except json.JSONDecodeError as e: yield create_final_llm_result_chunk( + id=message_id, index=chunk_index + 1, message=AssistantPromptMessage(content=""), finish_reason="Non-JSON encountered.", + usage=usage, ) break + if chunk_json: + if u := chunk_json.get("usage"): + usage = u if not chunk_json or len(chunk_json["choices"]) == 0: continue choice = chunk_json["choices"][0] finish_reason = chunk_json["choices"][0].get("finish_reason") + message_id = chunk_json.get("id") chunk_index += 1 if "delta" in choice: @@ -524,6 +535,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): continue yield LLMResultChunk( + id=message_id, model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( @@ -536,6 +548,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): if tools_calls: yield LLMResultChunk( + id=message_id, model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( @@ -545,17 +558,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): ) yield create_final_llm_result_chunk( - index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason + id=message_id, + index=chunk_index, + message=AssistantPromptMessage(content=""), + finish_reason=finish_reason, + usage=usage, ) def _handle_generate_response( self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] ) -> LLMResult: - response_json = response.json() + response_json: dict = response.json() completion_type = LLMMode.value_of(credentials["mode"]) output = response_json["choices"][0] + message_id = response_json.get("id") response_content = "" tool_calls = None @@ -593,6 +611,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): # transform response result = LLMResult( + id=message_id, model=response_json["model"], prompt_messages=prompt_messages, message=assistant_message, diff --git a/api/core/rag/datasource/vdb/couchbase/__init__.py b/api/core/rag/datasource/vdb/couchbase/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py new file mode 100644 index 0000000000..3f88d2ca2b --- /dev/null +++ b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py @@ -0,0 +1,378 @@ +import json +import logging +import time +import uuid +from datetime import timedelta +from typing import Any + +from couchbase import search +from couchbase.auth import PasswordAuthenticator +from couchbase.cluster import Cluster +from couchbase.management.search import SearchIndex + +# needed for options -- cluster, timeout, SQL++ (N1QL) query, etc. +from couchbase.options import ClusterOptions, SearchOptions +from couchbase.vector_search import VectorQuery, VectorSearch +from flask import current_app +from pydantic import BaseModel, model_validator + +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + +logger = logging.getLogger(__name__) + + +class CouchbaseConfig(BaseModel): + connection_string: str + user: str + password: str + bucket_name: str + scope_name: str + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + if not values.get("connection_string"): + raise ValueError("config COUCHBASE_CONNECTION_STRING is required") + if not values.get("user"): + raise ValueError("config COUCHBASE_USER is required") + if not values.get("password"): + raise ValueError("config COUCHBASE_PASSWORD is required") + if not values.get("bucket_name"): + raise ValueError("config COUCHBASE_PASSWORD is required") + if not values.get("scope_name"): + raise ValueError("config COUCHBASE_SCOPE_NAME is required") + return values + + +class CouchbaseVector(BaseVector): + def __init__(self, collection_name: str, config: CouchbaseConfig): + super().__init__(collection_name) + self._client_config = config + + """Connect to couchbase""" + + auth = PasswordAuthenticator(config.user, config.password) + options = ClusterOptions(auth) + self._cluster = Cluster(config.connection_string, options) + self._bucket = self._cluster.bucket(config.bucket_name) + self._scope = self._bucket.scope(config.scope_name) + self._bucket_name = config.bucket_name + self._scope_name = config.scope_name + + # Wait until the cluster is ready for use. + self._cluster.wait_until_ready(timedelta(seconds=5)) + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + index_id = str(uuid.uuid4()).replace("-", "") + self._create_collection(uuid=index_id, vector_length=len(embeddings[0])) + self.add_texts(texts, embeddings) + + def _create_collection(self, vector_length: int, uuid: str): + lock_name = "vector_indexing_lock_{}".format(self._collection_name) + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + if redis_client.get(collection_exist_cache_key): + return + if self._collection_exists(self._collection_name): + return + manager = self._bucket.collections() + manager.create_collection(self._client_config.scope_name, self._collection_name) + + index_manager = self._scope.search_indexes() + + index_definition = json.loads(""" +{ + "type": "fulltext-index", + "name": "Embeddings._default.Vector_Search", + "uuid": "26d4db528e78b716", + "sourceType": "gocbcore", + "sourceName": "Embeddings", + "sourceUUID": "2242e4a25b4decd6650c9c7b3afa1dbf", + "planParams": { + "maxPartitionsPerPIndex": 1024, + "indexPartitions": 1 + }, + "params": { + "doc_config": { + "docid_prefix_delim": "", + "docid_regexp": "", + "mode": "scope.collection.type_field", + "type_field": "type" + }, + "mapping": { + "analysis": { }, + "default_analyzer": "standard", + "default_datetime_parser": "dateTimeOptional", + "default_field": "_all", + "default_mapping": { + "dynamic": true, + "enabled": true + }, + "default_type": "_default", + "docvalues_dynamic": false, + "index_dynamic": true, + "store_dynamic": true, + "type_field": "_type", + "types": { + "collection_name": { + "dynamic": true, + "enabled": true, + "properties": { + "embedding": { + "dynamic": false, + "enabled": true, + "fields": [ + { + "dims": 1536, + "index": true, + "name": "embedding", + "similarity": "dot_product", + "type": "vector", + "vector_index_optimized_for": "recall" + } + ] + }, + "metadata": { + "dynamic": true, + "enabled": true + }, + "text": { + "dynamic": false, + "enabled": true, + "fields": [ + { + "index": true, + "name": "text", + "store": true, + "type": "text" + } + ] + } + } + } + } + }, + "store": { + "indexType": "scorch", + "segmentVersion": 16 + } + }, + "sourceParams": { } + } +""") + index_definition["name"] = self._collection_name + "_search" + index_definition["uuid"] = uuid + index_definition["params"]["mapping"]["types"]["collection_name"]["properties"]["embedding"]["fields"][0][ + "dims" + ] = vector_length + index_definition["params"]["mapping"]["types"][self._scope_name + "." + self._collection_name] = ( + index_definition["params"]["mapping"]["types"].pop("collection_name") + ) + time.sleep(2) + index_manager.upsert_index( + SearchIndex( + index_definition["name"], + params=index_definition["params"], + source_name=self._bucket_name, + ), + ) + time.sleep(1) + + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + def _collection_exists(self, name: str): + scope_collection_map: dict[str, Any] = {} + + # Get a list of all scopes in the bucket + for scope in self._bucket.collections().get_all_scopes(): + scope_collection_map[scope.name] = [] + + # Get a list of all the collections in the scope + for collection in scope.collections: + scope_collection_map[scope.name].append(collection.name) + + # Check if the collection exists in the scope + return self._collection_name in scope_collection_map[self._scope_name] + + def get_type(self) -> str: + return VectorType.COUCHBASE + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + uuids = self._get_uuids(documents) + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + + doc_ids = [] + + documents_to_insert = [ + {"text": text, "embedding": vector, "metadata": metadata} + for id, text, vector, metadata in zip(uuids, texts, embeddings, metadatas) + ] + for doc, id in zip(documents_to_insert, uuids): + result = self._scope.collection(self._collection_name).upsert(id, doc) + + doc_ids.extend(uuids) + + return doc_ids + + def text_exists(self, id: str) -> bool: + # Use a parameterized query for safety and correctness + query = f""" + SELECT COUNT(1) AS count FROM + `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name} + WHERE META().id = $doc_id + """ + # Pass the id as a parameter to the query + result = self._cluster.query(query, named_parameters={"doc_id": id}).execute() + for row in result: + return row["count"] > 0 + return False # Return False if no rows are returned + + def delete_by_ids(self, ids: list[str]) -> None: + query = f""" + DELETE FROM `{self._bucket_name}`.{self._client_config.scope_name}.{self._collection_name} + WHERE META().id IN $doc_ids; + """ + try: + self._cluster.query(query, named_parameters={"doc_ids": ids}).execute() + except Exception as e: + logger.error(e) + + def delete_by_document_id(self, document_id: str): + query = f""" + DELETE FROM + `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name} + WHERE META().id = $doc_id; + """ + self._cluster.query(query, named_parameters={"doc_id": document_id}).execute() + + # def get_ids_by_metadata_field(self, key: str, value: str): + # query = f""" + # SELECT id FROM + # `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name} + # WHERE `metadata.{key}` = $value; + # """ + # result = self._cluster.query(query, named_parameters={'value':value}) + # return [row['id'] for row in result.rows()] + + def delete_by_metadata_field(self, key: str, value: str) -> None: + query = f""" + DELETE FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name} + WHERE metadata.{key} = $value; + """ + self._cluster.query(query, named_parameters={"value": value}).execute() + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + top_k = kwargs.get("top_k", 5) + score_threshold = kwargs.get("score_threshold") or 0.0 + + search_req = search.SearchRequest.create( + VectorSearch.from_vector_query( + VectorQuery( + "embedding", + query_vector, + top_k, + ) + ) + ) + try: + search_iter = self._scope.search( + self._collection_name + "_search", + search_req, + SearchOptions(limit=top_k, collections=[self._collection_name], fields=["*"]), + ) + + docs = [] + # Parse the results + for row in search_iter.rows(): + text = row.fields.pop("text") + metadata = self._format_metadata(row.fields) + score = row.score + metadata["score"] = score + doc = Document(page_content=text, metadata=metadata) + if score >= score_threshold: + docs.append(doc) + except Exception as e: + raise ValueError(f"Search failed with error: {e}") + + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + top_k = kwargs.get("top_k", 2) + try: + CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) + search_iter = self._scope.search( + self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"]) + ) + + docs = [] + for row in search_iter.rows(): + text = row.fields.pop("text") + metadata = self._format_metadata(row.fields) + score = row.score + metadata["score"] = score + doc = Document(page_content=text, metadata=metadata) + docs.append(doc) + + except Exception as e: + raise ValueError(f"Search failed with error: {e}") + + return docs + + def delete(self): + manager = self._bucket.collections() + scopes = manager.get_all_scopes() + + for scope in scopes: + for collection in scope.collections: + if collection.name == self._collection_name: + manager.drop_collection("_default", self._collection_name) + + def _format_metadata(self, row_fields: dict[str, Any]) -> dict[str, Any]: + """Helper method to format the metadata from the Couchbase Search API. + Args: + row_fields (Dict[str, Any]): The fields to format. + + Returns: + Dict[str, Any]: The formatted metadata. + """ + metadata = {} + for key, value in row_fields.items(): + # Couchbase Search returns the metadata key with a prefix + # `metadata.` We remove it to get the original metadata key + if key.startswith("metadata"): + new_key = key.split("metadata" + ".")[-1] + metadata[new_key] = value + else: + metadata[key] = value + + return metadata + + +class CouchbaseVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> CouchbaseVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.COUCHBASE, collection_name)) + + config = current_app.config + return CouchbaseVector( + collection_name=collection_name, + config=CouchbaseConfig( + connection_string=config.get("COUCHBASE_CONNECTION_STRING"), + user=config.get("COUCHBASE_USER"), + password=config.get("COUCHBASE_PASSWORD"), + bucket_name=config.get("COUCHBASE_BUCKET_NAME"), + scope_name=config.get("COUCHBASE_SCOPE_NAME"), + ), + ) diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 052a187225..c62042af80 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -142,7 +142,7 @@ class ElasticSearchVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: query_str = {"match": {Field.CONTENT_KEY.value: query}} - results = self._client.search(index=self._collection_name, query=query_str) + results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4)) docs = [] for hit in results["hits"]["hits"]: docs.append( diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py new file mode 100644 index 0000000000..1e62b3c589 --- /dev/null +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py @@ -0,0 +1,17 @@ +from typing import Optional + +from pydantic import BaseModel + + +class ClusterEntity(BaseModel): + """ + Model Config Entity. + """ + + name: str + cluster_id: str + displayName: str + region: str + spendingLimit: Optional[int] = 1000 + version: str + createdBy: str diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py new file mode 100644 index 0000000000..a38f84636e --- /dev/null +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -0,0 +1,526 @@ +import json +import os +import uuid +from collections.abc import Generator, Iterable, Sequence +from itertools import islice +from typing import TYPE_CHECKING, Any, Optional, Union, cast + +import qdrant_client +import requests +from flask import current_app +from pydantic import BaseModel +from qdrant_client.http import models as rest +from qdrant_client.http.models import ( + FilterSelector, + HnswConfigDiff, + PayloadSchemaType, + TextIndexParams, + TextIndexType, + TokenizerType, +) +from qdrant_client.local.qdrant_local import QdrantLocal +from requests.auth import HTTPDigestAuth + +from configs import dify_config +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, TidbAuthBinding + +if TYPE_CHECKING: + from qdrant_client import grpc # noqa + from qdrant_client.conversions import common_types + from qdrant_client.http import models as rest + + DictFilter = dict[str, Union[str, int, bool, dict, list]] + MetadataFilter = Union[DictFilter, common_types.Filter] + + +class TidbOnQdrantConfig(BaseModel): + endpoint: str + api_key: Optional[str] = None + timeout: float = 20 + root_path: Optional[str] = None + grpc_port: int = 6334 + prefer_grpc: bool = False + + def to_qdrant_params(self): + if self.endpoint and self.endpoint.startswith("path:"): + path = self.endpoint.replace("path:", "") + if not os.path.isabs(path): + path = os.path.join(self.root_path, path) + + return {"path": path} + else: + return { + "url": self.endpoint, + "api_key": self.api_key, + "timeout": self.timeout, + "verify": False, + "grpc_port": self.grpc_port, + "prefer_grpc": self.prefer_grpc, + } + + +class TidbConfig(BaseModel): + api_url: str + public_key: str + private_key: str + + +class TidbOnQdrantVector(BaseVector): + def __init__(self, collection_name: str, group_id: str, config: TidbOnQdrantConfig, distance_func: str = "Cosine"): + super().__init__(collection_name) + self._client_config = config + self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) + self._distance_func = distance_func.upper() + self._group_id = group_id + + def get_type(self) -> str: + return VectorType.TIDB_ON_QDRANT + + def to_index_struct(self) -> dict: + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + if texts: + # get embedding vector size + vector_size = len(embeddings[0]) + # get collection name + collection_name = self._collection_name + # create collection + self.create_collection(collection_name, vector_size) + + self.add_texts(texts, embeddings, **kwargs) + + def create_collection(self, collection_name: str, vector_size: int): + lock_name = "vector_indexing_lock_{}".format(collection_name) + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + if redis_client.get(collection_exist_cache_key): + return + collection_name = collection_name or uuid.uuid4().hex + all_collection_name = [] + collections_response = self._client.get_collections() + collection_list = collections_response.collections + for collection in collection_list: + all_collection_name.append(collection.name) + if collection_name not in all_collection_name: + from qdrant_client.http import models as rest + + vectors_config = rest.VectorParams( + size=vector_size, + distance=rest.Distance[self._distance_func], + ) + hnsw_config = HnswConfigDiff( + m=0, + payload_m=16, + ef_construct=100, + full_scan_threshold=10000, + max_indexing_threads=0, + on_disk=False, + ) + self._client.recreate_collection( + collection_name=collection_name, + vectors_config=vectors_config, + hnsw_config=hnsw_config, + timeout=int(self._client_config.timeout), + ) + + # create group_id payload index + self._client.create_payload_index( + collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD + ) + # create doc_id payload index + self._client.create_payload_index( + collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD + ) + # create full text index + text_index_params = TextIndexParams( + type=TextIndexType.TEXT, + tokenizer=TokenizerType.MULTILINGUAL, + min_token_len=2, + max_token_len=20, + lowercase=True, + ) + self._client.create_payload_index( + collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params + ) + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + uuids = self._get_uuids(documents) + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + + added_ids = [] + for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): + self._client.upsert(collection_name=self._collection_name, points=points) + added_ids.extend(batch_ids) + + return added_ids + + def _generate_rest_batches( + self, + texts: Iterable[str], + embeddings: list[list[float]], + metadatas: Optional[list[dict]] = None, + ids: Optional[Sequence[str]] = None, + batch_size: int = 64, + group_id: Optional[str] = None, + ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: + from qdrant_client.http import models as rest + + texts_iterator = iter(texts) + embeddings_iterator = iter(embeddings) + metadatas_iterator = iter(metadatas or []) + ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) + while batch_texts := list(islice(texts_iterator, batch_size)): + # Take the corresponding metadata and id for each text in a batch + batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None + batch_ids = list(islice(ids_iterator, batch_size)) + + # Generate the embeddings for all the texts in a batch + batch_embeddings = list(islice(embeddings_iterator, batch_size)) + + points = [ + rest.PointStruct( + id=point_id, + vector=vector, + payload=payload, + ) + for point_id, vector, payload in zip( + batch_ids, + batch_embeddings, + self._build_payloads( + batch_texts, + batch_metadatas, + Field.CONTENT_KEY.value, + Field.METADATA_KEY.value, + group_id, + Field.GROUP_KEY.value, + ), + ) + ] + + yield batch_ids, points + + @classmethod + def _build_payloads( + cls, + texts: Iterable[str], + metadatas: Optional[list[dict]], + content_payload_key: str, + metadata_payload_key: str, + group_id: str, + group_payload_key: str, + ) -> list[dict]: + payloads = [] + for i, text in enumerate(texts): + if text is None: + raise ValueError( + "At least one of the texts is None. Please remove it before " + "calling .from_texts or .add_texts on Qdrant instance." + ) + metadata = metadatas[i] if metadatas is not None else None + payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id}) + + return payloads + + def delete_by_metadata_field(self, key: str, value: str): + from qdrant_client.http import models + from qdrant_client.http.exceptions import UnexpectedResponse + + try: + filter = models.Filter( + must=[ + models.FieldCondition( + key=f"metadata.{key}", + match=models.MatchValue(value=value), + ), + ], + ) + + self._reload_if_needed() + + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector(filter=filter), + ) + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + return + # Some other error occurred, so re-raise the exception + else: + raise e + + def delete(self): + from qdrant_client.http.exceptions import UnexpectedResponse + + try: + self._client.delete_collection(collection_name=self._collection_name) + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + return + # Some other error occurred, so re-raise the exception + else: + raise e + + def delete_by_ids(self, ids: list[str]) -> None: + from qdrant_client.http import models + from qdrant_client.http.exceptions import UnexpectedResponse + + for node_id in ids: + try: + filter = models.Filter( + must=[ + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchValue(value=node_id), + ), + ], + ) + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector(filter=filter), + ) + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + return + # Some other error occurred, so re-raise the exception + else: + raise e + + def text_exists(self, id: str) -> bool: + all_collection_name = [] + collections_response = self._client.get_collections() + collection_list = collections_response.collections + for collection in collection_list: + all_collection_name.append(collection.name) + if self._collection_name not in all_collection_name: + return False + response = self._client.retrieve(collection_name=self._collection_name, ids=[id]) + + return len(response) > 0 + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + from qdrant_client.http import models + + filter = models.Filter( + must=[ + models.FieldCondition( + key="group_id", + match=models.MatchValue(value=self._group_id), + ), + ], + ) + results = self._client.search( + collection_name=self._collection_name, + query_vector=query_vector, + query_filter=filter, + limit=kwargs.get("top_k", 4), + with_payload=True, + with_vectors=True, + score_threshold=kwargs.get("score_threshold", 0.0), + ) + docs = [] + for result in results: + metadata = result.payload.get(Field.METADATA_KEY.value) or {} + # duplicate check score threshold + score_threshold = kwargs.get("score_threshold") or 0.0 + if result.score > score_threshold: + metadata["score"] = result.score + doc = Document( + page_content=result.payload.get(Field.CONTENT_KEY.value), + metadata=metadata, + ) + docs.append(doc) + # Sort the documents by score in descending order + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """Return docs most similar by bm25. + Returns: + List of documents most similar to the query text and distance for each. + """ + from qdrant_client.http import models + + scroll_filter = models.Filter( + must=[ + models.FieldCondition( + key="page_content", + match=models.MatchText(text=query), + ) + ] + ) + response = self._client.scroll( + collection_name=self._collection_name, + scroll_filter=scroll_filter, + limit=kwargs.get("top_k", 2), + with_payload=True, + with_vectors=True, + ) + results = response[0] + documents = [] + for result in results: + if result: + document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value) + document.metadata["vector"] = result.vector + documents.append(document) + + return documents + + def _reload_if_needed(self): + if isinstance(self._client, QdrantLocal): + self._client = cast(QdrantLocal, self._client) + self._client._load() + + @classmethod + def _document_from_scored_point( + cls, + scored_point: Any, + content_payload_key: str, + metadata_payload_key: str, + ) -> Document: + return Document( + page_content=scored_point.payload.get(content_payload_key), + metadata=scored_point.payload.get(metadata_payload_key) or {}, + ) + + +class TidbOnQdrantVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: + tidb_auth_binding = ( + db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() + ) + if not tidb_auth_binding: + idle_tidb_auth_binding = ( + db.session.query(TidbAuthBinding) + .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") + .limit(1) + .one_or_none() + ) + if idle_tidb_auth_binding: + idle_tidb_auth_binding.active = True + idle_tidb_auth_binding.tenant_id = dataset.tenant_id + db.session.commit() + TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" + else: + with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): + tidb_auth_binding = ( + db.session.query(TidbAuthBinding) + .filter(TidbAuthBinding.tenant_id == dataset.tenant_id) + .one_or_none() + ) + if tidb_auth_binding: + TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" + + else: + new_cluster = TidbService.create_tidb_serverless_cluster( + dify_config.TIDB_PROJECT_ID, + dify_config.TIDB_API_URL, + dify_config.TIDB_IAM_API_URL, + dify_config.TIDB_PUBLIC_KEY, + dify_config.TIDB_PRIVATE_KEY, + dify_config.TIDB_REGION, + ) + new_tidb_auth_binding = TidbAuthBinding( + cluster_id=new_cluster["cluster_id"], + cluster_name=new_cluster["cluster_name"], + account=new_cluster["account"], + password=new_cluster["password"], + tenant_id=dataset.tenant_id, + active=True, + status="ACTIVE", + ) + db.session.add(new_tidb_auth_binding) + db.session.commit() + TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}" + + else: + TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" + + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_ON_QDRANT, collection_name)) + + config = current_app.config + + return TidbOnQdrantVector( + collection_name=collection_name, + group_id=dataset.id, + config=TidbOnQdrantConfig( + endpoint=dify_config.TIDB_ON_QDRANT_URL, + api_key=TIDB_ON_QDRANT_API_KEY, + root_path=config.root_path, + timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT, + grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT, + prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED, + ), + ) + + def create_tidb_serverless_cluster(self, tidb_config: TidbConfig, display_name: str, region: str): + """ + Creates a new TiDB Serverless cluster. + :param tidb_config: The configuration for the TiDB Cloud API. + :param display_name: The user-friendly display name of the cluster (required). + :param region: The region where the cluster will be created (required). + + :return: The response from the API. + """ + region_object = { + "name": region, + } + + labels = { + "tidb.cloud/project": "1372813089454548012", + } + cluster_data = {"displayName": display_name, "region": region_object, "labels": labels} + + response = requests.post( + f"{tidb_config.api_url}/clusters", + json=cluster_data, + auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key), + ) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + def change_tidb_serverless_root_password(self, tidb_config: TidbConfig, cluster_id: str, new_password: str): + """ + Changes the root password of a specific TiDB Serverless cluster. + + :param tidb_config: The configuration for the TiDB Cloud API. + :param cluster_id: The ID of the cluster for which the password is to be changed (required). + :param new_password: The new password for the root user (required). + :return: The response from the API. + """ + + body = {"password": new_password} + + response = requests.put( + f"{tidb_config.api_url}/clusters/{cluster_id}/password", + json=body, + auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key), + ) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py new file mode 100644 index 0000000000..f10d6339ee --- /dev/null +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -0,0 +1,250 @@ +import time +import uuid + +import requests +from requests.auth import HTTPDigestAuth + +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import TidbAuthBinding + + +class TidbService: + @staticmethod + def create_tidb_serverless_cluster( + project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str + ): + """ + Creates a new TiDB Serverless cluster. + :param project_id: The project ID of the TiDB Cloud project (required). + :param api_url: The URL of the TiDB Cloud API (required). + :param iam_url: The URL of the TiDB Cloud IAM API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param display_name: The user-friendly display name of the cluster (required). + :param region: The region where the cluster will be created (required). + + :return: The response from the API. + """ + + region_object = { + "name": region, + } + + labels = { + "tidb.cloud/project": project_id, + } + + spending_limit = { + "monthly": 100, + } + password = str(uuid.uuid4()).replace("-", "")[:16] + display_name = str(uuid.uuid4()).replace("-", "")[:16] + cluster_data = { + "displayName": display_name, + "region": region_object, + "labels": labels, + "spendingLimit": spending_limit, + "rootPassword": password, + } + + response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key)) + + if response.status_code == 200: + response_data = response.json() + cluster_id = response_data["clusterId"] + retry_count = 0 + max_retries = 30 + while retry_count < max_retries: + cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) + if cluster_response["state"] == "ACTIVE": + user_prefix = cluster_response["userPrefix"] + return { + "cluster_id": cluster_id, + "cluster_name": display_name, + "account": f"{user_prefix}.root", + "password": password, + } + time.sleep(30) # wait 30 seconds before retrying + retry_count += 1 + else: + response.raise_for_status() + + @staticmethod + def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str): + """ + Deletes a specific TiDB Serverless cluster. + + :param api_url: The URL of the TiDB Cloud API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param cluster_id: The ID of the cluster to be deleted (required). + :return: The response from the API. + """ + + response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key)) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + @staticmethod + def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str): + """ + Deletes a specific TiDB Serverless cluster. + + :param api_url: The URL of the TiDB Cloud API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param cluster_id: The ID of the cluster to be deleted (required). + :return: The response from the API. + """ + + response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key)) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + @staticmethod + def change_tidb_serverless_root_password( + api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str + ): + """ + Changes the root password of a specific TiDB Serverless cluster. + + :param api_url: The URL of the TiDB Cloud API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param cluster_id: The ID of the cluster for which the password is to be changed (required).+ + :param account: The account for which the password is to be changed (required). + :param new_password: The new password for the root user (required). + :return: The response from the API. + """ + + body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []} + + response = requests.patch( + f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}", + json=body, + auth=HTTPDigestAuth(public_key, private_key), + ) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + @staticmethod + def batch_update_tidb_serverless_cluster_status( + tidb_serverless_list: list[TidbAuthBinding], + project_id: str, + api_url: str, + iam_url: str, + public_key: str, + private_key: str, + ) -> list[dict]: + """ + Update the status of a new TiDB Serverless cluster. + :param project_id: The project ID of the TiDB Cloud project (required). + :param api_url: The URL of the TiDB Cloud API (required). + :param iam_url: The URL of the TiDB Cloud IAM API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param display_name: The user-friendly display name of the cluster (required). + :param region: The region where the cluster will be created (required). + + :return: The response from the API. + """ + clusters = [] + tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} + cluster_ids = [item.cluster_id for item in tidb_serverless_list] + params = {"clusterIds": cluster_ids, "view": "FULL"} + response = requests.get( + f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key) + ) + + if response.status_code == 200: + response_data = response.json() + cluster_infos = [] + for item in response_data["clusters"]: + state = item["state"] + userPrefix = item["userPrefix"] + if state == "ACTIVE" and len(userPrefix) > 0: + cluster_info = tidb_serverless_list_map[item["clusterId"]] + cluster_info.status = "ACTIVE" + cluster_info.account = f"{userPrefix}.root" + db.session.add(cluster_info) + db.session.commit() + else: + response.raise_for_status() + + @staticmethod + def batch_create_tidb_serverless_cluster( + batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str + ) -> list[dict]: + """ + Creates a new TiDB Serverless cluster. + :param project_id: The project ID of the TiDB Cloud project (required). + :param api_url: The URL of the TiDB Cloud API (required). + :param iam_url: The URL of the TiDB Cloud IAM API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param display_name: The user-friendly display name of the cluster (required). + :param region: The region where the cluster will be created (required). + + :return: The response from the API. + """ + clusters = [] + for _ in range(batch_size): + region_object = { + "name": region, + } + + labels = { + "tidb.cloud/project": project_id, + } + + spending_limit = { + "monthly": 10, + } + password = str(uuid.uuid4()).replace("-", "")[:16] + display_name = str(uuid.uuid4()).replace("-", "") + cluster_data = { + "cluster": { + "displayName": display_name, + "region": region_object, + "labels": labels, + "spendingLimit": spending_limit, + "rootPassword": password, + } + } + cache_key = f"tidb_serverless_cluster_password:{display_name}" + redis_client.setex(cache_key, 3600, password) + clusters.append(cluster_data) + + request_body = {"requests": clusters} + response = requests.post( + f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key) + ) + + if response.status_code == 200: + response_data = response.json() + cluster_infos = [] + for item in response_data["clusters"]: + cache_key = f"tidb_serverless_cluster_password:{item['displayName']}" + password = redis_client.get(cache_key) + if not password: + continue + cluster_info = { + "cluster_id": item["clusterId"], + "cluster_name": item["displayName"], + "account": "root", + "password": password.decode("utf-8"), + } + cluster_infos.append(cluster_info) + return cluster_infos + else: + response.raise_for_status() diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 9ea3cf4b6b..87d19bf60b 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -9,8 +9,9 @@ from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document +from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import Dataset +from models.dataset import Dataset, Whitelist class AbstractVectorFactory(ABC): @@ -35,8 +36,18 @@ class Vector: def _init_vector(self) -> BaseVector: vector_type = dify_config.VECTOR_STORE + if self._dataset.index_struct_dict: vector_type = self._dataset.index_struct_dict["type"] + else: + if dify_config.VECTOR_STORE_WHITELIST_ENABLE: + whitelist = ( + db.session.query(Whitelist) + .filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") + .one_or_none() + ) + if whitelist: + vector_type = VectorType.TIDB_ON_QDRANT if not vector_type: raise ValueError("Vector store must be specified.") @@ -103,6 +114,10 @@ class Vector: from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory return AnalyticdbVectorFactory + case VectorType.COUCHBASE: + from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory + + return CouchbaseVectorFactory case VectorType.BAIDU: from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory @@ -115,6 +130,10 @@ class Vector: from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory return UpstashVectorFactory + case VectorType.TIDB_ON_QDRANT: + from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory + + return TidbOnQdrantVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 30aa814553..7384c12ff7 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -16,6 +16,8 @@ class VectorType(str, Enum): TENCENT = "tencent" ORACLE = "oracle" ELASTICSEARCH = "elasticsearch" + COUCHBASE = "couchbase" BAIDU = "baidu" VIKINGDB = "vikingdb" UPSTASH = "upstash" + TIDB_ON_QDRANT = "tidb_on_qdrant" diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 603f7555dd..a0b1aa4cef 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -21,7 +21,6 @@ from core.rag.extractor.unstructured.unstructured_eml_extractor import Unstructu from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor -from core.rag.extractor.unstructured.unstructured_pdf_extractor import UnstructuredPDFExtractor from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor @@ -103,7 +102,7 @@ class ExtractProcessor: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": - extractor = UnstructuredPDFExtractor(file_path, unstructured_api_url, unstructured_api_key) + extractor = PdfExtractor(file_path) elif file_extension in {".md", ".markdown"}: extractor = ( UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key) @@ -122,6 +121,8 @@ class ExtractProcessor: extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url, unstructured_api_key) elif file_extension == ".ppt": extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key) + # You must first specify the API key + # because unstructured_api_key is necessary to parse .ppt documents elif file_extension == ".pptx": extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url, unstructured_api_key) elif file_extension == ".xml": diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index a5375991b4..ae3c25125c 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -234,7 +234,7 @@ class WordExtractor(BaseExtractor): def parse_paragraph(paragraph): paragraph_content = [] for run in paragraph.runs: - if hasattr(run.element, "tag") and isinstance(element.tag, str) and run.element.tag.endswith("r"): + if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"): drawing_elements = run.element.findall( ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing" ) diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 9a31e673d3..d8637fd2cb 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -204,7 +204,7 @@ class ToolParameter(BaseModel): return str(value) except Exception: - raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.") + raise ValueError(f"The tool parameter value {value} is not in correct type.") class ToolParameterForm(Enum): SCHEMA = "schema" # should be set while adding tool diff --git a/api/core/tools/provider/builtin/aliyuque/tools/base.py b/api/core/tools/provider/builtin/aliyuque/tools/base.py index fb7e219bff..edfb9fea8e 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/base.py +++ b/api/core/tools/provider/builtin/aliyuque/tools/base.py @@ -1,10 +1,3 @@ -""" -语雀客户端 -""" - -__author__ = "佐井" -__created__ = "2024-06-01 09:45:20" - from typing import Any import requests @@ -29,14 +22,13 @@ class AliYuqueTool: session = requests.Session() session.headers.update({"accept": "application/json", "X-Auth-Token": token}) new_params = {**tool_parameters} - # 找出需要替换的变量 + replacements = {k: v for k, v in new_params.items() if f"{{{k}}}" in path} - # 替换 path 中的变量 for key, value in replacements.items(): path = path.replace(f"{{{key}}}", str(value)) - del new_params[key] # 从 kwargs 中删除已经替换的变量 - # 请求接口 + del new_params[key] + if method.upper() in {"POST", "PUT"}: session.headers.update( { diff --git a/api/core/tools/provider/builtin/aliyuque/tools/create_document.py b/api/core/tools/provider/builtin/aliyuque/tools/create_document.py index feadc29258..01080fd1d5 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/create_document.py +++ b/api/core/tools/provider/builtin/aliyuque/tools/create_document.py @@ -1,10 +1,3 @@ -""" -创建文档 -""" - -__author__ = "佐井" -__created__ = "2024-06-01 10:45:20" - from typing import Any, Union from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/aliyuque/tools/create_document.yaml b/api/core/tools/provider/builtin/aliyuque/tools/create_document.yaml index b9d1c60327..6ac8ae6696 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/create_document.yaml +++ b/api/core/tools/provider/builtin/aliyuque/tools/create_document.yaml @@ -13,7 +13,7 @@ description: parameters: - name: book_id - type: number + type: string required: true form: llm label: diff --git a/api/core/tools/provider/builtin/aliyuque/tools/delete_document.py b/api/core/tools/provider/builtin/aliyuque/tools/delete_document.py index 74c731a944..84237cec30 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/delete_document.py +++ b/api/core/tools/provider/builtin/aliyuque/tools/delete_document.py @@ -1,11 +1,3 @@ -#!/usr/bin/env python3 -""" -删除文档 -""" - -__author__ = "佐井" -__created__ = "2024-09-17 22:04" - from typing import Any, Union from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/aliyuque/tools/delete_document.yaml b/api/core/tools/provider/builtin/aliyuque/tools/delete_document.yaml index 87372c5350..dddd62d304 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/delete_document.yaml +++ b/api/core/tools/provider/builtin/aliyuque/tools/delete_document.yaml @@ -13,7 +13,7 @@ description: parameters: - name: book_id - type: number + type: string required: true form: llm label: diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.py index 02bf603a24..c23d30059a 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.py +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.py @@ -1,10 +1,3 @@ -""" -获取知识库首页 -""" - -__author__ = "佐井" -__created__ = "2024-06-01 22:57:14" - from typing import Any, Union from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.py index fcfe449c6d..36f8c10d6f 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.py +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.py @@ -1,11 +1,3 @@ -#!/usr/bin/env python3 -""" -获取知识库目录 -""" - -__author__ = "佐井" -__created__ = "2024-09-17 15:17:11" - from typing import Any, Union from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.yaml b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.yaml index 0c2bd22132..0a481b59eb 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.yaml +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.yaml @@ -13,7 +13,7 @@ description: parameters: - name: book_id - type: number + type: string required: true form: llm label: diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.py index 1e70593879..a69bf121f7 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.py +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.py @@ -1,10 +1,3 @@ -""" -获取文档 -""" - -__author__ = "佐井" -__created__ = "2024-06-02 07:11:45" - import json from typing import Any, Union from urllib.parse import urlparse @@ -37,7 +30,6 @@ class AliYuqueDescribeDocumentContentTool(AliYuqueTool, BuiltinTool): book_slug = path_parts[-2] group_id = path_parts[-3] - # 1. 请求首页信息,获取book_id new_params["group_login"] = group_id new_params["book_slug"] = book_slug index_page = json.loads( @@ -46,7 +38,7 @@ class AliYuqueDescribeDocumentContentTool(AliYuqueTool, BuiltinTool): book_id = index_page.get("data", {}).get("book", {}).get("id") if not book_id: raise Exception(f"can not parse book_id from {index_page}") - # 2. 获取文档内容 + new_params["book_id"] = book_id new_params["id"] = doc_id data = self.request("GET", token, new_params, "/api/v2/repos/{book_id}/docs/{id}") diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.py index ed1b2a8643..7a45684bed 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.py +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.py @@ -1,10 +1,3 @@ -""" -获取文档 -""" - -__author__ = "佐井" -__created__ = "2024-06-01 10:45:20" - from typing import Any, Union from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.yaml b/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.yaml index 5156345d71..0b14c1afba 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.yaml +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.yaml @@ -14,7 +14,7 @@ description: parameters: - name: book_id - type: number + type: string required: true form: llm label: diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.py b/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.py index 932559445e..ca0a3909f8 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.py +++ b/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.py @@ -1,11 +1,3 @@ -#!/usr/bin/env python3 -""" -获取知识库目录 -""" - -__author__ = "佐井" -__created__ = "2024-09-17 15:17:11" - from typing import Any, Union from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.yaml b/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.yaml index f0c0024f17..f85970348b 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.yaml +++ b/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.yaml @@ -13,7 +13,7 @@ description: parameters: - name: book_id - type: number + type: string required: true form: llm label: diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_document.py b/api/core/tools/provider/builtin/aliyuque/tools/update_document.py index 0c6e0205e1..d7eba46ad9 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/update_document.py +++ b/api/core/tools/provider/builtin/aliyuque/tools/update_document.py @@ -1,10 +1,3 @@ -""" -更新文档 -""" - -__author__ = "佐井" -__created__ = "2024-06-19 16:50:07" - from typing import Any, Union from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_document.yaml b/api/core/tools/provider/builtin/aliyuque/tools/update_document.yaml index 87f88c9b1b..c2da6b179a 100644 --- a/api/core/tools/provider/builtin/aliyuque/tools/update_document.yaml +++ b/api/core/tools/provider/builtin/aliyuque/tools/update_document.yaml @@ -12,7 +12,7 @@ description: llm: Update doc in a knowledge base via ID/path. parameters: - name: book_id - type: number + type: string required: true form: llm label: diff --git a/api/core/tools/provider/builtin/baidu_translate/_assets/icon.png b/api/core/tools/provider/builtin/baidu_translate/_assets/icon.png new file mode 100644 index 0000000000..8eb8f21513 Binary files /dev/null and b/api/core/tools/provider/builtin/baidu_translate/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/baidu_translate/_baidu_translate_tool_base.py b/api/core/tools/provider/builtin/baidu_translate/_baidu_translate_tool_base.py new file mode 100644 index 0000000000..ce907c3c61 --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/_baidu_translate_tool_base.py @@ -0,0 +1,11 @@ +from hashlib import md5 + + +class BaiduTranslateToolBase: + def _get_sign(self, appid, secret, salt, query): + """ + get baidu translate sign + """ + # concatenate the string in the order of appid+q+salt+secret + str = appid + query + salt + secret + return md5(str.encode("utf-8")).hexdigest() diff --git a/api/core/tools/provider/builtin/baidu_translate/baidu_translate.py b/api/core/tools/provider/builtin/baidu_translate/baidu_translate.py new file mode 100644 index 0000000000..cccd2f8c8f --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/baidu_translate.py @@ -0,0 +1,17 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.baidu_translate.tools.translate import BaiduTranslateTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class BaiduTranslateProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + BaiduTranslateTool().fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ).invoke(user_id="", tool_parameters={"q": "这是一段测试文本", "from": "auto", "to": "en"}) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/baidu_translate/baidu_translate.yaml b/api/core/tools/provider/builtin/baidu_translate/baidu_translate.yaml new file mode 100644 index 0000000000..06dadeeefc --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/baidu_translate.yaml @@ -0,0 +1,39 @@ +identity: + author: Xiao Ley + name: baidu_translate + label: + en_US: Baidu Translate + zh_Hans: 百度翻译 + description: + en_US: Translate text using Baidu + zh_Hans: 使用百度进行翻译 + icon: icon.png + tags: + - utilities +credentials_for_provider: + appid: + type: secret-input + required: true + label: + en_US: Baidu translate appid + zh_Hans: Baidu translate appid + placeholder: + en_US: Please input your Baidu translate appid + zh_Hans: 请输入你的百度翻译 appid + help: + en_US: Get your Baidu translate appid from Baidu translate + zh_Hans: 从百度翻译开放平台获取你的 appid + url: https://api.fanyi.baidu.com + secret: + type: secret-input + required: true + label: + en_US: Baidu translate secret + zh_Hans: Baidu translate secret + placeholder: + en_US: Please input your Baidu translate secret + zh_Hans: 请输入你的百度翻译 secret + help: + en_US: Get your Baidu translate secret from Baidu translate + zh_Hans: 从百度翻译开放平台获取你的 secret + url: https://api.fanyi.baidu.com diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.py b/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.py new file mode 100644 index 0000000000..bce259f31d --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.py @@ -0,0 +1,78 @@ +import random +from hashlib import md5 +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.baidu_translate._baidu_translate_tool_base import BaiduTranslateToolBase +from core.tools.tool.builtin_tool import BuiltinTool + + +class BaiduFieldTranslateTool(BuiltinTool, BaiduTranslateToolBase): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + BAIDU_FIELD_TRANSLATE_URL = "https://fanyi-api.baidu.com/api/trans/vip/fieldtranslate" + + appid = self.runtime.credentials.get("appid", "") + if not appid: + raise ValueError("invalid baidu translate appid") + + secret = self.runtime.credentials.get("secret", "") + if not secret: + raise ValueError("invalid baidu translate secret") + + q = tool_parameters.get("q", "") + if not q: + raise ValueError("Please input text to translate") + + from_ = tool_parameters.get("from", "") + if not from_: + raise ValueError("Please select source language") + + to = tool_parameters.get("to", "") + if not to: + raise ValueError("Please select destination language") + + domain = tool_parameters.get("domain", "") + if not domain: + raise ValueError("Please select domain") + + salt = str(random.randint(32768, 16777215)) + sign = self._get_sign(appid, secret, salt, q, domain) + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + params = { + "q": q, + "from": from_, + "to": to, + "appid": appid, + "salt": salt, + "domain": domain, + "sign": sign, + "needIntervene": 1, + } + try: + response = requests.post(BAIDU_FIELD_TRANSLATE_URL, headers=headers, data=params) + result = response.json() + + if "trans_result" in result: + result_text = result["trans_result"][0]["dst"] + else: + result_text = f'{result["error_code"]}: {result["error_msg"]}' + + return self.create_text_message(str(result_text)) + except requests.RequestException as e: + raise ValueError(f"Translation service error: {e}") + except Exception: + raise ValueError("Translation service error, please check the network") + + def _get_sign(self, appid, secret, salt, query, domain): + str = appid + query + salt + domain + secret + return md5(str.encode("utf-8")).hexdigest() diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.yaml b/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.yaml new file mode 100644 index 0000000000..de51fddbae --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.yaml @@ -0,0 +1,123 @@ +identity: + name: field_translate + author: Xiao Ley + label: + en_US: Field translate + zh_Hans: 百度领域翻译 +description: + human: + en_US: A tool for Baidu Field translate (Currently, the fields of "novel" and "wiki" only support Chinese to English translation. If the language direction is set to English to Chinese, the default output will be a universal translation result). + zh_Hans: 百度领域翻译,提供多种领域的文本翻译(目前“网络文学领域”和“人文社科领域”仅支持中到英,如设置语言方向为英到中,则默认输出通用翻译结果) + llm: A tool for Baidu Field translate +parameters: + - name: q + type: string + required: true + label: + en_US: Text content + zh_Hans: 文本内容 + human_description: + en_US: Text content to be translated + zh_Hans: 需要翻译的文本内容 + llm_description: Text content to be translated + form: llm + - name: from + type: select + required: true + label: + en_US: source language + zh_Hans: 源语言 + human_description: + en_US: The source language of the input text + zh_Hans: 输入的文本的源语言 + default: auto + form: form + options: + - value: auto + label: + en_US: auto + zh_Hans: 自动检测 + - value: zh + label: + en_US: Chinese + zh_Hans: 中文 + - value: en + label: + en_US: English + zh_Hans: 英语 + - name: to + type: select + required: true + label: + en_US: destination language + zh_Hans: 目标语言 + human_description: + en_US: The destination language of the input text + zh_Hans: 输入文本的目标语言 + default: en + form: form + options: + - value: zh + label: + en_US: Chinese + zh_Hans: 中文 + - value: en + label: + en_US: English + zh_Hans: 英语 + - name: domain + type: select + required: true + label: + en_US: domain + zh_Hans: 领域 + human_description: + en_US: The domain of the input text + zh_Hans: 输入文本的领域 + default: novel + form: form + options: + - value: it + label: + en_US: it + zh_Hans: 信息技术领域 + - value: finance + label: + en_US: finance + zh_Hans: 金融财经领域 + - value: machinery + label: + en_US: machinery + zh_Hans: 机械制造领域 + - value: senimed + label: + en_US: senimed + zh_Hans: 生物医药领域 + - value: novel + label: + en_US: novel (only support Chinese to English translation) + zh_Hans: 网络文学领域(仅支持中到英) + - value: academic + label: + en_US: academic + zh_Hans: 学术论文领域 + - value: aerospace + label: + en_US: aerospace + zh_Hans: 航空航天领域 + - value: wiki + label: + en_US: wiki (only support Chinese to English translation) + zh_Hans: 人文社科领域(仅支持中到英) + - value: news + label: + en_US: news + zh_Hans: 新闻咨询领域 + - value: law + label: + en_US: law + zh_Hans: 法律法规领域 + - value: contract + label: + en_US: contract + zh_Hans: 合同领域 diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/language.py b/api/core/tools/provider/builtin/baidu_translate/tools/language.py new file mode 100644 index 0000000000..3bbaee88b3 --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/tools/language.py @@ -0,0 +1,95 @@ +import random +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.baidu_translate._baidu_translate_tool_base import BaiduTranslateToolBase +from core.tools.tool.builtin_tool import BuiltinTool + + +class BaiduLanguageTool(BuiltinTool, BaiduTranslateToolBase): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + BAIDU_LANGUAGE_URL = "https://fanyi-api.baidu.com/api/trans/vip/language" + + appid = self.runtime.credentials.get("appid", "") + if not appid: + raise ValueError("invalid baidu translate appid") + + secret = self.runtime.credentials.get("secret", "") + if not secret: + raise ValueError("invalid baidu translate secret") + + q = tool_parameters.get("q", "") + if not q: + raise ValueError("Please input text to translate") + + description_language = tool_parameters.get("description_language", "English") + + salt = str(random.randint(32768, 16777215)) + sign = self._get_sign(appid, secret, salt, q) + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + params = { + "q": q, + "appid": appid, + "salt": salt, + "sign": sign, + } + + try: + response = requests.post(BAIDU_LANGUAGE_URL, params=params, headers=headers) + result = response.json() + if "error_code" not in result: + raise ValueError("Translation service error, please check the network") + + result_text = "" + if result["error_code"] != 0: + result_text = f'{result["error_code"]}: {result["error_msg"]}' + else: + result_text = result["data"]["src"] + result_text = self.mapping_result(description_language, result_text) + + return self.create_text_message(result_text) + except requests.RequestException as e: + raise ValueError(f"Translation service error: {e}") + except Exception: + raise ValueError("Translation service error, please check the network") + + def mapping_result(self, description_language: str, result: str) -> str: + """ + mapping result + """ + mapping = { + "English": { + "zh": "Chinese", + "en": "English", + "jp": "Japanese", + "kor": "Korean", + "th": "Thai", + "vie": "Vietnamese", + "ru": "Russian", + }, + "Chinese": { + "zh": "中文", + "en": "英文", + "jp": "日文", + "kor": "韩文", + "th": "泰语", + "vie": "越南语", + "ru": "俄语", + }, + } + + language_mapping = mapping.get(description_language) + if not language_mapping: + return result + + return language_mapping.get(result, result) diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/language.yaml b/api/core/tools/provider/builtin/baidu_translate/tools/language.yaml new file mode 100644 index 0000000000..60cca2e288 --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/tools/language.yaml @@ -0,0 +1,43 @@ +identity: + name: language + author: Xiao Ley + label: + en_US: Baidu Language + zh_Hans: 百度语种识别 +description: + human: + en_US: A tool for Baidu Language, support Chinese, English, Japanese, Korean, Thai, Vietnamese and Russian + zh_Hans: 使用百度进行语种识别,支持的语种:中文、英语、日语、韩语、泰语、越南语和俄语 + llm: A tool for Baidu Language +parameters: + - name: q + type: string + required: true + label: + en_US: Text content + zh_Hans: 文本内容 + human_description: + en_US: Text content to be recognized + zh_Hans: 需要识别语言的文本内容 + llm_description: Text content to be recognized + form: llm + - name: description_language + type: select + required: true + label: + en_US: Description language + zh_Hans: 描述语言 + human_description: + en_US: Describe the language used to identify the results + zh_Hans: 描述识别结果所用的语言 + default: Chinese + form: form + options: + - value: Chinese + label: + en_US: Chinese + zh_Hans: 中文 + - value: English + label: + en_US: English + zh_Hans: 英语 diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/translate.py b/api/core/tools/provider/builtin/baidu_translate/tools/translate.py new file mode 100644 index 0000000000..7cd816a3bc --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/tools/translate.py @@ -0,0 +1,67 @@ +import random +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.baidu_translate._baidu_translate_tool_base import BaiduTranslateToolBase +from core.tools.tool.builtin_tool import BuiltinTool + + +class BaiduTranslateTool(BuiltinTool, BaiduTranslateToolBase): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + BAIDU_TRANSLATE_URL = "https://fanyi-api.baidu.com/api/trans/vip/translate" + + appid = self.runtime.credentials.get("appid", "") + if not appid: + raise ValueError("invalid baidu translate appid") + + secret = self.runtime.credentials.get("secret", "") + if not secret: + raise ValueError("invalid baidu translate secret") + + q = tool_parameters.get("q", "") + if not q: + raise ValueError("Please input text to translate") + + from_ = tool_parameters.get("from", "") + if not from_: + raise ValueError("Please select source language") + + to = tool_parameters.get("to", "") + if not to: + raise ValueError("Please select destination language") + + salt = str(random.randint(32768, 16777215)) + sign = self._get_sign(appid, secret, salt, q) + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + params = { + "q": q, + "from": from_, + "to": to, + "appid": appid, + "salt": salt, + "sign": sign, + } + try: + response = requests.post(BAIDU_TRANSLATE_URL, params=params, headers=headers) + result = response.json() + + if "trans_result" in result: + result_text = result["trans_result"][0]["dst"] + else: + result_text = f'{result["error_code"]}: {result["error_msg"]}' + + return self.create_text_message(str(result_text)) + except requests.RequestException as e: + raise ValueError(f"Translation service error: {e}") + except Exception: + raise ValueError("Translation service error, please check the network") diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/translate.yaml b/api/core/tools/provider/builtin/baidu_translate/tools/translate.yaml new file mode 100644 index 0000000000..c8ff32cb6b --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/tools/translate.yaml @@ -0,0 +1,275 @@ +identity: + name: translate + author: Xiao Ley + label: + en_US: Translate + zh_Hans: 百度翻译 +description: + human: + en_US: A tool for Baidu Translate + zh_Hans: 百度翻译 + llm: A tool for Baidu Translate +parameters: + - name: q + type: string + required: true + label: + en_US: Text content + zh_Hans: 文本内容 + human_description: + en_US: Text content to be translated + zh_Hans: 需要翻译的文本内容 + llm_description: Text content to be translated + form: llm + - name: from + type: select + required: true + label: + en_US: source language + zh_Hans: 源语言 + human_description: + en_US: The source language of the input text + zh_Hans: 输入的文本的源语言 + default: auto + form: form + options: + - value: auto + label: + en_US: auto + zh_Hans: 自动检测 + - value: zh + label: + en_US: Chinese + zh_Hans: 中文 + - value: en + label: + en_US: English + zh_Hans: 英语 + - value: cht + label: + en_US: Traditional Chinese + zh_Hans: 繁体中文 + - value: yue + label: + en_US: Yue + zh_Hans: 粤语 + - value: wyw + label: + en_US: Wyw + zh_Hans: 文言文 + - value: jp + label: + en_US: Japanese + zh_Hans: 日语 + - value: kor + label: + en_US: Korean + zh_Hans: 韩语 + - value: fra + label: + en_US: French + zh_Hans: 法语 + - value: spa + label: + en_US: Spanish + zh_Hans: 西班牙语 + - value: th + label: + en_US: Thai + zh_Hans: 泰语 + - value: ara + label: + en_US: Arabic + zh_Hans: 阿拉伯语 + - value: ru + label: + en_US: Russian + zh_Hans: 俄语 + - value: pt + label: + en_US: Portuguese + zh_Hans: 葡萄牙语 + - value: de + label: + en_US: German + zh_Hans: 德语 + - value: it + label: + en_US: Italian + zh_Hans: 意大利语 + - value: el + label: + en_US: Greek + zh_Hans: 希腊语 + - value: nl + label: + en_US: Dutch + zh_Hans: 荷兰语 + - value: pl + label: + en_US: Polish + zh_Hans: 波兰语 + - value: bul + label: + en_US: Bulgarian + zh_Hans: 保加利亚语 + - value: est + label: + en_US: Estonian + zh_Hans: 爱沙尼亚语 + - value: dan + label: + en_US: Danish + zh_Hans: 丹麦语 + - value: fin + label: + en_US: Finnish + zh_Hans: 芬兰语 + - value: cs + label: + en_US: Czech + zh_Hans: 捷克语 + - value: rom + label: + en_US: Romanian + zh_Hans: 罗马尼亚语 + - value: slo + label: + en_US: Slovak + zh_Hans: 斯洛文尼亚语 + - value: swe + label: + en_US: Swedish + zh_Hans: 瑞典语 + - value: hu + label: + en_US: Hungarian + zh_Hans: 匈牙利语 + - value: vie + label: + en_US: Vietnamese + zh_Hans: 越南语 + - name: to + type: select + required: true + label: + en_US: destination language + zh_Hans: 目标语言 + human_description: + en_US: The destination language of the input text + zh_Hans: 输入文本的目标语言 + default: en + form: form + options: + - value: zh + label: + en_US: Chinese + zh_Hans: 中文 + - value: en + label: + en_US: English + zh_Hans: 英语 + - value: cht + label: + en_US: Traditional Chinese + zh_Hans: 繁体中文 + - value: yue + label: + en_US: Yue + zh_Hans: 粤语 + - value: wyw + label: + en_US: Wyw + zh_Hans: 文言文 + - value: jp + label: + en_US: Japanese + zh_Hans: 日语 + - value: kor + label: + en_US: Korean + zh_Hans: 韩语 + - value: fra + label: + en_US: French + zh_Hans: 法语 + - value: spa + label: + en_US: Spanish + zh_Hans: 西班牙语 + - value: th + label: + en_US: Thai + zh_Hans: 泰语 + - value: ara + label: + en_US: Arabic + zh_Hans: 阿拉伯语 + - value: ru + label: + en_US: Russian + zh_Hans: 俄语 + - value: pt + label: + en_US: Portuguese + zh_Hans: 葡萄牙语 + - value: de + label: + en_US: German + zh_Hans: 德语 + - value: it + label: + en_US: Italian + zh_Hans: 意大利语 + - value: el + label: + en_US: Greek + zh_Hans: 希腊语 + - value: nl + label: + en_US: Dutch + zh_Hans: 荷兰语 + - value: pl + label: + en_US: Polish + zh_Hans: 波兰语 + - value: bul + label: + en_US: Bulgarian + zh_Hans: 保加利亚语 + - value: est + label: + en_US: Estonian + zh_Hans: 爱沙尼亚语 + - value: dan + label: + en_US: Danish + zh_Hans: 丹麦语 + - value: fin + label: + en_US: Finnish + zh_Hans: 芬兰语 + - value: cs + label: + en_US: Czech + zh_Hans: 捷克语 + - value: rom + label: + en_US: Romanian + zh_Hans: 罗马尼亚语 + - value: slo + label: + en_US: Slovak + zh_Hans: 斯洛文尼亚语 + - value: swe + label: + en_US: Swedish + zh_Hans: 瑞典语 + - value: hu + label: + en_US: Hungarian + zh_Hans: 匈牙利语 + - value: vie + label: + en_US: Vietnamese + zh_Hans: 越南语 diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py index 8a24d33428..209d6ecba4 100644 --- a/api/core/tools/provider/builtin/chart/chart.py +++ b/api/core/tools/provider/builtin/chart/chart.py @@ -1,77 +1,36 @@ import matplotlib.pyplot as plt -from fontTools.ttLib import TTFont -from matplotlib.font_manager import findSystemFonts +from matplotlib.font_manager import FontProperties -from core.tools.errors import ToolProviderCredentialValidationError -from core.tools.provider.builtin.chart.tools.line import LinearChartTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + +def set_chinese_font(): + font_list = [ + "PingFang SC", + "SimHei", + "Microsoft YaHei", + "STSong", + "SimSun", + "Arial Unicode MS", + "Noto Sans CJK SC", + "Noto Sans CJK JP", + ] + + for font in font_list: + chinese_font = FontProperties(font) + if chinese_font.get_name() == font: + return chinese_font + + return FontProperties() + + # use a business theme plt.style.use("seaborn-v0_8-darkgrid") plt.rcParams["axes.unicode_minus"] = False - - -def init_fonts(): - fonts = findSystemFonts() - - popular_unicode_fonts = [ - "Arial Unicode MS", - "DejaVu Sans", - "DejaVu Sans Mono", - "DejaVu Serif", - "FreeMono", - "FreeSans", - "FreeSerif", - "Liberation Mono", - "Liberation Sans", - "Liberation Serif", - "Noto Mono", - "Noto Sans", - "Noto Serif", - "Open Sans", - "Roboto", - "Source Code Pro", - "Source Sans Pro", - "Source Serif Pro", - "Ubuntu", - "Ubuntu Mono", - ] - - supported_fonts = [] - - for font_path in fonts: - try: - font = TTFont(font_path) - # get family name - family_name = font["name"].getName(1, 3, 1).toUnicode() - if family_name in popular_unicode_fonts: - supported_fonts.append(family_name) - except: - pass - - plt.rcParams["font.family"] = "sans-serif" - # sort by order of popular_unicode_fonts - for font in popular_unicode_fonts: - if font in supported_fonts: - plt.rcParams["font.sans-serif"] = font - break - - -init_fonts() +font_properties = set_chinese_font() +plt.rcParams["font.family"] = font_properties.get_name() class ChartProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: - try: - LinearChartTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id="", - tool_parameters={ - "data": "1,3,5,7,9,2,4,6,8,10", - }, - ) - except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) + pass diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py index a41d34d40f..d4bf713441 100644 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py @@ -1,3 +1,5 @@ +import base64 +import io import json import random import uuid @@ -6,45 +8,48 @@ import httpx from websocket import WebSocket from yarl import URL +from core.file.file_manager import _get_encoded_string +from core.file.models import File + class ComfyUiClient: def __init__(self, base_url: str): self.base_url = URL(base_url) - def get_history(self, prompt_id: str): + def get_history(self, prompt_id: str) -> dict: res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id}) history = res.json()[prompt_id] return history - def get_image(self, filename: str, subfolder: str, folder_type: str): + def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes: response = httpx.get( str(self.base_url / "view"), params={"filename": filename, "subfolder": subfolder, "type": folder_type}, ) return response.content - def upload_image(self, input_path: str, name: str, image_type: str = "input", overwrite: bool = False): - # plan to support img2img in dify 0.10.0 - with open(input_path, "rb") as file: - files = {"image": (name, file, "image/png")} - data = {"type": image_type, "overwrite": str(overwrite).lower()} + def upload_image(self, image_file: File) -> dict: + image_content = base64.b64decode(_get_encoded_string(image_file)) + file = io.BytesIO(image_content) + files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"} + res = httpx.post(str(self.base_url / "upload/image"), files=files) + return res.json() - res = httpx.post(str(self.base_url / "upload/image"), data=data, files=files) - return res - - def queue_prompt(self, client_id: str, prompt: dict): + def queue_prompt(self, client_id: str, prompt: dict) -> str: res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt}) prompt_id = res.json()["prompt_id"] return prompt_id - def open_websocket_connection(self): + def open_websocket_connection(self) -> tuple[WebSocket, str]: client_id = str(uuid.uuid4()) ws = WebSocket() ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}" ws.connect(ws_address) return ws, client_id - def set_prompt(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = ""): + def set_prompt( + self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = "" + ) -> dict: """ find the first KSampler, then can find the prompt node through it. """ @@ -58,6 +63,10 @@ class ComfyUiClient: if negative_prompt != "": negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0] prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt + + if image_name != "": + image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0] + prompt.get(image_loader)["inputs"]["image"] = image_name return prompt def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str): @@ -89,7 +98,7 @@ class ComfyUiClient: else: continue - def generate_image_by_prompt(self, prompt: dict): + def generate_image_by_prompt(self, prompt: dict) -> list[bytes]: try: ws, client_id = self.open_websocket_connection() prompt_id = self.queue_prompt(client_id, prompt) diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py index e4df9f8c3b..11320d5d0f 100644 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py @@ -2,10 +2,9 @@ import json from typing import Any from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient from core.tools.tool.builtin_tool import BuiltinTool -from .comfyui_client import ComfyUiClient - class ComfyUIWorkflowTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: @@ -14,13 +13,16 @@ class ComfyUIWorkflowTool(BuiltinTool): positive_prompt = tool_parameters.get("positive_prompt") negative_prompt = tool_parameters.get("negative_prompt") workflow = tool_parameters.get("workflow_json") + image_name = "" + if image := tool_parameters.get("image"): + image_name = comfyui.upload_image(image).get("name") try: origin_prompt = json.loads(workflow) except: return self.create_text_message("the Workflow JSON is not correct") - prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt) + prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt, image_name) images = comfyui.generate_image_by_prompt(prompt) result = [] for img in images: diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml index 6342d6d468..55fcdad825 100644 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml @@ -24,6 +24,13 @@ parameters: zh_Hans: 负面提示词 llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English. form: llm + - name: image + type: file + label: + en_US: Input Image + zh_Hans: 输入的图片 + llm_description: The input image, used to transfer to the comfyui workflow to generate another image. + form: llm - name: workflow_json type: string required: true diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py index 3173fb9e13..54bb38755a 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py @@ -2,7 +2,6 @@ from typing import Any from duckduckgo_search import DDGS -from core.file.models import FileTransferMethod from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -20,11 +19,9 @@ class DuckDuckGoImageSearchTool(BuiltinTool): "max_results": tool_parameters.get("max_results"), } response = DDGS().images(**query_dict) - result = [] + markdown_result = "\n\n" + json_result = [] for res in response: - res["transfer_method"] = FileTransferMethod.REMOTE_URL - msg = ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=res.get("image"), save_as="", meta=res - ) - result.append(msg) - return result + markdown_result += f"![{res.get('title') or ''}]({res.get('image') or ''})" + json_result.append(self.create_json_message(res)) + return [self.create_text_message(markdown_result)] + json_result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py index d6a0b03d1b..db43790c06 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py @@ -5,9 +5,12 @@ import requests from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -SDURL = { - "sd_3": "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-3-medium/text-to-image", - "sd_xl": "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-xl-base-1.0/text-to-image", +SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/image/generations" + +SD_MODELS = { + "sd_3": "stabilityai/stable-diffusion-3-medium", + "sd_xl": "stabilityai/stable-diffusion-xl-base-1.0", + "sd_3.5_large": "stabilityai/stable-diffusion-3-5-large", } @@ -22,9 +25,10 @@ class StableDiffusionTool(BuiltinTool): } model = tool_parameters.get("model", "sd_3") - url = SDURL.get(model) + sd_model = SD_MODELS.get(model) payload = { + "model": sd_model, "prompt": tool_parameters.get("prompt"), "negative_prompt": tool_parameters.get("negative_prompt", ""), "image_size": tool_parameters.get("image_size", "1024x1024"), @@ -34,7 +38,7 @@ class StableDiffusionTool(BuiltinTool): "num_inference_steps": tool_parameters.get("num_inference_steps", 20), } - response = requests.post(url, json=payload, headers=headers) + response = requests.post(SILICONFLOW_API_URL, json=payload, headers=headers) if response.status_code != 200: return self.create_text_message(f"Got Error Response:{response.text}") diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml index dce10adc87..b330c92e16 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml @@ -40,6 +40,9 @@ parameters: - value: sd_xl label: en_US: Stable Diffusion XL + - value: sd_3.5_large + label: + en_US: Stable Diffusion 3.5 Large default: sd_3 label: en_US: Choose Image Model diff --git a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py b/api/core/tools/provider/builtin/vectorizer/tools/test_data.py deleted file mode 100644 index 8effa9818a..0000000000 --- a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py +++ /dev/null @@ -1 +0,0 @@ -VECTORIZER_ICON_PNG = "iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC" # noqa: E501 diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py index 4bd601c0bd..c722cd36c8 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py @@ -1,11 +1,12 @@ -from base64 import b64decode from typing import Any, Union from httpx import post +from core.file.enums import FileType +from core.file.file_manager import download +from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from core.tools.errors import ToolProviderCredentialValidationError -from core.tools.provider.builtin.vectorizer.tools.test_data import VECTORIZER_ICON_PNG +from core.tools.errors import ToolParameterValidationError from core.tools.tool.builtin_tool import BuiltinTool @@ -16,30 +17,30 @@ class VectorizerTool(BuiltinTool): """ invoke tools """ - api_key_name = self.runtime.credentials.get("api_key_name", None) - api_key_value = self.runtime.credentials.get("api_key_value", None) + api_key_name = self.runtime.credentials.get("api_key_name") + api_key_value = self.runtime.credentials.get("api_key_value") mode = tool_parameters.get("mode", "test") - if mode == "production": - mode = "preview" - - if not api_key_name or not api_key_value: - raise ToolProviderCredentialValidationError("Please input api key name and value") + # image file for workflow mode + image = tool_parameters.get("image") + if image and image.type != FileType.IMAGE: + raise ToolParameterValidationError("Not a valid image") + # image_id for agent mode image_id = tool_parameters.get("image_id", "") - if not image_id: - return self.create_text_message("Please input image id") - if image_id.startswith("__test_"): - image_binary = b64decode(VECTORIZER_ICON_PNG) - else: + if image_id: image_binary = self.get_variable_file(self.VariableKey.IMAGE) if not image_binary: return self.create_text_message("Image not found, please request user to generate image firstly.") + elif image: + image_binary = download(image) + else: + raise ToolParameterValidationError("Please provide either image or image_id") response = post( "https://vectorizer.ai/api/v1/vectorize", + data={"mode": mode}, files={"image": image_binary}, - data={"mode": mode} if mode == "test" else {}, auth=(api_key_name, api_key_value), timeout=30, ) @@ -59,11 +60,23 @@ class VectorizerTool(BuiltinTool): return [ ToolParameter.get_simple_instance( name="image_id", - llm_description=f"the image id that you want to vectorize, \ - and the image id should be specified in \ + llm_description=f"the image_id that you want to vectorize, \ + and the image_id should be specified in \ {[i.name for i in self.list_default_image_variables()]}", type=ToolParameter.ToolParameterType.SELECT, - required=True, + required=False, options=[i.name for i in self.list_default_image_variables()], - ) + ), + ToolParameter( + name="image", + label=I18nObject(en_US="image", zh_Hans="image"), + human_description=I18nObject( + en_US="The image to be converted.", + zh_Hans="要转换的图片。", + ), + type=ToolParameter.ToolParameterType.FILE, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="you should not input this parameter. just input the image_id.", + required=False, + ), ] diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.yaml b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.yaml index 4b4fb9e245..0afd1c201f 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.yaml +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.yaml @@ -4,14 +4,21 @@ identity: label: en_US: Vectorizer.AI zh_Hans: Vectorizer.AI - pt_BR: Vectorizer.AI description: human: en_US: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI. zh_Hans: 一个将 PNG 和 JPG 图像快速轻松地转换为 SVG 矢量图的工具。 - pt_BR: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI. llm: A tool for converting images to SVG vectors. you should input the image id as the input of this tool. the image id can be got from parameters. parameters: + - name: image + type: file + label: + en_US: image + human_description: + en_US: The image to be converted. + zh_Hans: 要转换的图片。 + llm_description: you should not input this parameter. just input the image_id. + form: llm - name: mode type: select required: true @@ -20,19 +27,15 @@ parameters: label: en_US: production zh_Hans: 生产模式 - pt_BR: production - value: test label: en_US: test zh_Hans: 测试模式 - pt_BR: test default: test label: en_US: Mode zh_Hans: 模式 - pt_BR: Mode human_description: en_US: It is free to integrate with and test out the API in test mode, no subscription required. zh_Hans: 在测试模式下,可以免费测试API。 - pt_BR: It is free to integrate with and test out the API in test mode, no subscription required. form: form diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py index 3b868572f9..8140348723 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -1,5 +1,7 @@ from typing import Any +from core.file import File +from core.file.enums import FileTransferMethod, FileType from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -7,6 +9,12 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class VectorizerProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: + test_img = File( + tenant_id="__test_123", + remote_url="https://cloud.dify.ai/logo/logo-site.png", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + ) try: VectorizerTool().fork_tool_runtime( runtime={ @@ -14,7 +22,7 @@ class VectorizerProvider(BuiltinToolProviderController): } ).invoke( user_id="", - tool_parameters={"mode": "test", "image_id": "__test_123"}, + tool_parameters={"mode": "test", "image": test_img}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml b/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml index 1257f8d285..94dae20876 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml @@ -4,11 +4,9 @@ identity: label: en_US: Vectorizer.AI zh_Hans: Vectorizer.AI - pt_BR: Vectorizer.AI description: en_US: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI. zh_Hans: 一个将 PNG 和 JPG 图像快速轻松地转换为 SVG 矢量图的工具。 - pt_BR: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI. icon: icon.png tags: - productivity @@ -20,15 +18,12 @@ credentials_for_provider: label: en_US: Vectorizer.AI API Key name zh_Hans: Vectorizer.AI API Key name - pt_BR: Vectorizer.AI API Key name placeholder: en_US: Please input your Vectorizer.AI ApiKey name zh_Hans: 请输入你的 Vectorizer.AI ApiKey name - pt_BR: Please input your Vectorizer.AI ApiKey name help: en_US: Get your Vectorizer.AI API Key from Vectorizer.AI. zh_Hans: 从 Vectorizer.AI 获取您的 Vectorizer.AI API Key。 - pt_BR: Get your Vectorizer.AI API Key from Vectorizer.AI. url: https://vectorizer.ai/api api_key_value: type: secret-input @@ -36,12 +31,9 @@ credentials_for_provider: label: en_US: Vectorizer.AI API Key zh_Hans: Vectorizer.AI API Key - pt_BR: Vectorizer.AI API Key placeholder: en_US: Please input your Vectorizer.AI ApiKey zh_Hans: 请输入你的 Vectorizer.AI ApiKey - pt_BR: Please input your Vectorizer.AI ApiKey help: en_US: Get your Vectorizer.AI API Key from Vectorizer.AI. zh_Hans: 从 Vectorizer.AI 获取您的 Vectorizer.AI API Key。 - pt_BR: Get your Vectorizer.AI API Key from Vectorizer.AI. diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 9e984732b7..63f7775164 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -242,11 +242,15 @@ class ToolManager: parameters = tool_entity.get_all_runtime_parameters() for parameter in parameters: # check file types - if parameter.type in { - ToolParameter.ToolParameterType.SYSTEM_FILES, - ToolParameter.ToolParameterType.FILE, - ToolParameter.ToolParameterType.FILES, - }: + if ( + parameter.type + in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + } + and parameter.required + ): raise ValueError(f"file type parameter {parameter.name} not supported in agent") if parameter.form == ToolParameter.ToolParameterForm.FORM: diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index ada0b14ce4..8f58af00ef 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -130,15 +130,14 @@ class GraphEngine: yield GraphRunStartedEvent() try: - stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor] if self.init_params.workflow_type == WorkflowType.CHAT: - stream_processor_cls = AnswerStreamProcessor + stream_processor = AnswerStreamProcessor( + graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool + ) else: - stream_processor_cls = EndStreamProcessor - - stream_processor = stream_processor_cls( - graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool - ) + stream_processor = EndStreamProcessor( + graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool + ) # run graph generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id)) diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index bce28c5fcb..bc4b056148 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -149,10 +149,10 @@ class AnswerStreamGeneratorRouter: source_node_id = edge.source_node_id source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") if source_node_type in { - NodeType.ANSWER.value, - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER.value, - NodeType.ITERATION.value, + NodeType.ANSWER, + NodeType.IF_ELSE, + NodeType.QUESTION_CLASSIFIER, + NodeType.ITERATION, }: answer_dependencies[answer_node_id].append(source_node_id) else: diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index e3889941ca..8a768088da 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -22,7 +22,7 @@ class AnswerStreamProcessor(StreamProcessor): super().__init__(graph, variable_pool) self.generate_routes = graph.answer_stream_generate_routes self.route_position = {} - for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items(): + for answer_node_id in self.generate_routes.answer_generate_route: self.route_position[answer_node_id] = 0 self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 52d0358c76..36c3fe180a 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -41,7 +41,6 @@ class StreamProcessor(ABC): continue else: unreachable_first_node_ids.append(edge.target_node_id) - unreachable_first_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) for node_id in unreachable_first_node_ids: self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index e543d02dd7..a05cc44c99 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from enum import Enum from pydantic import BaseModel, Field @@ -32,7 +33,7 @@ class VarGenerateRouteChunk(GenerateRouteChunk): type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR """generate route chunk type""" - value_selector: list[str] = Field(..., description="value selector") + value_selector: Sequence[str] = Field(..., description="value selector") class TextGenerateRouteChunk(GenerateRouteChunk): diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 3efcc373b1..9e09b6d29a 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -1,5 +1,6 @@ import csv import io +import json import docx import pandas as pd @@ -75,36 +76,62 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): ) -def _extract_text(*, file_content: bytes, mime_type: str) -> str: +def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: """Extract text from a file based on its MIME type.""" - if mime_type.startswith("text/plain") or mime_type in {"text/html", "text/htm", "text/markdown", "text/xml"}: - return _extract_text_from_plain_text(file_content) - elif mime_type == "application/pdf": - return _extract_text_from_pdf(file_content) - elif mime_type in { - "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - "application/msword", - }: - return _extract_text_from_doc(file_content) - elif mime_type == "text/csv": - return _extract_text_from_csv(file_content) - elif mime_type in { - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - "application/vnd.ms-excel", - }: - return _extract_text_from_excel(file_content) - elif mime_type == "application/vnd.ms-powerpoint": - return _extract_text_from_ppt(file_content) - elif mime_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation": - return _extract_text_from_pptx(file_content) - elif mime_type == "application/epub+zip": - return _extract_text_from_epub(file_content) - elif mime_type == "message/rfc822": - return _extract_text_from_eml(file_content) - elif mime_type == "application/vnd.ms-outlook": - return _extract_text_from_msg(file_content) - else: - raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") + match mime_type: + case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml": + return _extract_text_from_plain_text(file_content) + case "application/pdf": + return _extract_text_from_pdf(file_content) + case "application/vnd.openxmlformats-officedocument.wordprocessingml.document" | "application/msword": + return _extract_text_from_doc(file_content) + case "text/csv": + return _extract_text_from_csv(file_content) + case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel": + return _extract_text_from_excel(file_content) + case "application/vnd.ms-powerpoint": + return _extract_text_from_ppt(file_content) + case "application/vnd.openxmlformats-officedocument.presentationml.presentation": + return _extract_text_from_pptx(file_content) + case "application/epub+zip": + return _extract_text_from_epub(file_content) + case "message/rfc822": + return _extract_text_from_eml(file_content) + case "application/vnd.ms-outlook": + return _extract_text_from_msg(file_content) + case "application/json": + return _extract_text_from_json(file_content) + case _: + raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") + + +def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str: + """Extract text from a file based on its file extension.""" + match file_extension: + case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml": + return _extract_text_from_plain_text(file_content) + case ".json": + return _extract_text_from_json(file_content) + case ".pdf": + return _extract_text_from_pdf(file_content) + case ".doc" | ".docx": + return _extract_text_from_doc(file_content) + case ".csv": + return _extract_text_from_csv(file_content) + case ".xls" | ".xlsx": + return _extract_text_from_excel(file_content) + case ".ppt": + return _extract_text_from_ppt(file_content) + case ".pptx": + return _extract_text_from_pptx(file_content) + case ".epub": + return _extract_text_from_epub(file_content) + case ".eml": + return _extract_text_from_eml(file_content) + case ".msg": + return _extract_text_from_msg(file_content) + case _: + raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}") def _extract_text_from_plain_text(file_content: bytes) -> str: @@ -114,6 +141,14 @@ def _extract_text_from_plain_text(file_content: bytes) -> str: raise TextExtractionError("Failed to decode plain text file") from e +def _extract_text_from_json(file_content: bytes) -> str: + try: + json_data = json.loads(file_content.decode("utf-8")) + return json.dumps(json_data, indent=2, ensure_ascii=False) + except (UnicodeDecodeError, json.JSONDecodeError) as e: + raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e + + def _extract_text_from_pdf(file_content: bytes) -> str: try: pdf_file = io.BytesIO(file_content) @@ -156,10 +191,13 @@ def _download_file_content(file: File) -> bytes: def _extract_text_from_file(file: File): - if file.mime_type is None: - raise UnsupportedFileTypeError("Unable to determine file type: MIME type is missing") file_content = _download_file_content(file) - extracted_text = _extract_text(file_content=file_content, mime_type=file.mime_type) + if file.extension: + extracted_text = _extract_text_by_file_extension(file_content=file_content, file_extension=file.extension) + elif file.mime_type: + extracted_text = _extract_text_by_mime_type(file_content=file_content, mime_type=file.mime_type) + else: + raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing") return extracted_text @@ -172,7 +210,7 @@ def _extract_text_from_csv(file_content: bytes) -> str: if not rows: return "" - # Create markdown table + # Create Markdown table markdown_table = "| " + " | ".join(rows[0]) + " |\n" markdown_table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n" for row in rows[1:]: @@ -192,7 +230,7 @@ def _extract_text_from_excel(file_content: bytes) -> str: # Drop rows where all elements are NaN df.dropna(how="all", inplace=True) - # Convert DataFrame to markdown table + # Convert DataFrame to Markdown table markdown_table = df.to_markdown(index=False) return markdown_table except Exception as e: diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 0270d7e0fd..6872478299 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -33,7 +33,7 @@ class Executor: params: Mapping[str, str] | None content: str | bytes | None data: Mapping[str, Any] | None - files: Mapping[str, bytes] | None + files: Mapping[str, tuple[str | None, bytes, str]] | None json: Any headers: dict[str, str] auth: HttpRequestNodeAuthorization @@ -141,7 +141,11 @@ class Executor: files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()} files = {k: v for k, v in files.items() if v is not None} files = {k: variable.value for k, variable in files.items()} - files = {k: file_manager.download(v) for k, v in files.items() if v.related_id is not None} + files = { + k: (v.filename, file_manager.download(v), v.mime_type or "application/octet-stream") + for k, v in files.items() + if v.related_id is not None + } self.data = form_data self.files = files diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 94aa8c5eab..abf77f3339 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -127,9 +127,10 @@ class LLMNode(BaseNode[LLMNodeData]): context=context, memory=memory, model_config=model_config, - vision_detail=self.node_data.vision.configs.detail, prompt_template=self.node_data.prompt_template, memory_config=self.node_data.memory, + vision_enabled=self.node_data.vision.enabled, + vision_detail=self.node_data.vision.configs.detail, ) process_data = { @@ -518,6 +519,7 @@ class LLMNode(BaseNode[LLMNodeData]): model_config: ModelConfigWithCredentialsEntity, prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, memory_config: MemoryConfig | None = None, + vision_enabled: bool = False, vision_detail: ImagePromptMessageContent.DETAIL, ) -> tuple[list[PromptMessage], Optional[list[str]]]: inputs = inputs or {} @@ -542,6 +544,10 @@ class LLMNode(BaseNode[LLMNodeData]): if not isinstance(prompt_message.content, str): prompt_message_content = [] for content_item in prompt_message.content or []: + # Skip image if vision is disabled + if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE: + continue + if isinstance(content_item, ImagePromptMessageContent): # Override vision config if LLM node has vision config, # cuz vision detail is related to the configuration from FileUpload feature. diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index e6af453dcf..ee160e7c69 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -88,6 +88,7 @@ class QuestionClassifierNode(LLMNode): memory=memory, model_config=model_config, files=files, + vision_enabled=node_data.vision.enabled, vision_detail=node_data.vision.configs.detail, ) diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index b9b019373d..504899c276 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,6 +1,7 @@ from datetime import timedelta from celery import Celery, Task +from celery.schedules import crontab from flask import Flask from configs import dify_config @@ -55,6 +56,8 @@ def init_app(app: Flask) -> Celery: imports = [ "schedule.clean_embedding_cache_task", "schedule.clean_unused_datasets_task", + "schedule.create_tidb_serverless_task", + "schedule.update_tidb_serverless_status_task", ] day = dify_config.CELERY_BEAT_SCHEDULER_TIME beat_schedule = { @@ -66,6 +69,14 @@ def init_app(app: Flask) -> Celery: "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task", "schedule": timedelta(days=day), }, + "create_tidb_serverless_task": { + "task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task", + "schedule": crontab(minute="0", hour="*"), + }, + "update_tidb_serverless_status_task": { + "task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task", + "schedule": crontab(minute="30", hour="*"), + }, } celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 01c1000e50..67635b129e 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -36,12 +36,9 @@ class AliyunOssStorage(BaseStorage): return data def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - obj = self.client.get_object(self.__wrapper_folder_filename(filename)) - while chunk := obj.read(4096): - yield chunk - - return generate() + obj = self.client.get_object(self.__wrapper_folder_filename(filename)) + while chunk := obj.read(4096): + yield chunk def download(self, filename, target_filepath): self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath) diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py index cb67313bb2..ab2d0fba3b 100644 --- a/api/extensions/storage/aws_s3_storage.py +++ b/api/extensions/storage/aws_s3_storage.py @@ -62,17 +62,14 @@ class AwsS3Storage(BaseStorage): return data def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - try: - response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response["Body"].iter_chunks() - except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": - raise FileNotFoundError("File not found") - else: - raise - - return generate() + try: + response = self.client.get_object(Bucket=self.bucket_name, Key=filename) + yield from response["Body"].iter_chunks() + except ClientError as ex: + if ex.response["Error"]["Code"] == "NoSuchKey": + raise FileNotFoundError("File not found") + else: + raise def download(self, filename, target_filepath): self.client.download_file(self.bucket_name, filename, target_filepath) diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index 477507feda..11a7544274 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -32,13 +32,9 @@ class AzureBlobStorage(BaseStorage): def load_stream(self, filename: str) -> Generator: client = self._sync_client() - - def generate(filename: str = filename) -> Generator: - blob = client.get_blob_client(container=self.bucket_name, blob=filename) - blob_data = blob.download_blob() - yield from blob_data.chunks() - - return generate(filename) + blob = client.get_blob_client(container=self.bucket_name, blob=filename) + blob_data = blob.download_blob() + yield from blob_data.chunks() def download(self, filename, target_filepath): client = self._sync_client() diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py index cd69439749..e0d2140e91 100644 --- a/api/extensions/storage/baidu_obs_storage.py +++ b/api/extensions/storage/baidu_obs_storage.py @@ -39,12 +39,9 @@ class BaiduObsStorage(BaseStorage): return response.data.read() def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data - while chunk := response.read(4096): - yield chunk - - return generate() + response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data + while chunk := response.read(4096): + yield chunk def download(self, filename, target_filepath): self.client.get_object_to_file(bucket_name=self.bucket_name, key=filename, file_name=target_filepath) diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index e90392a6ba..26b662d2f0 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -39,14 +39,11 @@ class GoogleCloudStorage(BaseStorage): return data def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - bucket = self.client.get_bucket(self.bucket_name) - blob = bucket.get_blob(filename) - with blob.open(mode="rb") as blob_stream: - while chunk := blob_stream.read(4096): - yield chunk - - return generate() + bucket = self.client.get_bucket(self.bucket_name) + blob = bucket.get_blob(filename) + with blob.open(mode="rb") as blob_stream: + while chunk := blob_stream.read(4096): + yield chunk def download(self, filename, target_filepath): bucket = self.client.get_bucket(self.bucket_name) diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index 3c443d87ac..20be70ef83 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -27,12 +27,9 @@ class HuaweiObsStorage(BaseStorage): return data def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response - while chunk := response.read(4096): - yield chunk - - return generate() + response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response + while chunk := response.read(4096): + yield chunk def download(self, filename, target_filepath): self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath) diff --git a/api/extensions/storage/local_fs_storage.py b/api/extensions/storage/local_fs_storage.py index e458b3ce8a..5a495ca4d4 100644 --- a/api/extensions/storage/local_fs_storage.py +++ b/api/extensions/storage/local_fs_storage.py @@ -19,68 +19,44 @@ class LocalFsStorage(BaseStorage): folder = os.path.join(current_app.root_path, folder) self.folder = folder - def save(self, filename, data): + def _build_filepath(self, filename: str) -> str: + """Build the full file path based on the folder and filename.""" if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename + return self.folder + filename else: - filename = self.folder + "/" + filename + return self.folder + "/" + filename - folder = os.path.dirname(filename) + def save(self, filename, data): + filepath = self._build_filepath(filename) + folder = os.path.dirname(filepath) os.makedirs(folder, exist_ok=True) - - Path(os.path.join(os.getcwd(), filename)).write_bytes(data) + Path(os.path.join(os.getcwd(), filepath)).write_bytes(data) def load_once(self, filename: str) -> bytes: - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - if not os.path.exists(filename): + filepath = self._build_filepath(filename) + if not os.path.exists(filepath): raise FileNotFoundError("File not found") - - data = Path(filename).read_bytes() - return data + return Path(filepath).read_bytes() def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - if not os.path.exists(filename): - raise FileNotFoundError("File not found") - - with open(filename, "rb") as f: - while chunk := f.read(4096): # Read in chunks of 4KB - yield chunk - - return generate() + filepath = self._build_filepath(filename) + if not os.path.exists(filepath): + raise FileNotFoundError("File not found") + with open(filepath, "rb") as f: + while chunk := f.read(4096): # Read in chunks of 4KB + yield chunk def download(self, filename, target_filepath): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - if not os.path.exists(filename): + filepath = self._build_filepath(filename) + if not os.path.exists(filepath): raise FileNotFoundError("File not found") - - shutil.copyfile(filename, target_filepath) + shutil.copyfile(filepath, target_filepath) def exists(self, filename): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - return os.path.exists(filename) + filepath = self._build_filepath(filename) + return os.path.exists(filepath) def delete(self, filename): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - if os.path.exists(filename): - os.remove(filename) + filepath = self._build_filepath(filename) + if os.path.exists(filepath): + os.remove(filepath) diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index e4f50b34e9..b59f83b8de 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -36,17 +36,14 @@ class OracleOCIStorage(BaseStorage): return data def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - try: - response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response["Body"].iter_chunks() - except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": - raise FileNotFoundError("File not found") - else: - raise - - return generate() + try: + response = self.client.get_object(Bucket=self.bucket_name, Key=filename) + yield from response["Body"].iter_chunks() + except ClientError as ex: + if ex.response["Error"]["Code"] == "NoSuchKey": + raise FileNotFoundError("File not found") + else: + raise def download(self, filename, target_filepath): self.client.download_file(self.bucket_name, filename, target_filepath) diff --git a/api/extensions/storage/supabase_storage.py b/api/extensions/storage/supabase_storage.py index 1119244574..9f7c69a9ae 100644 --- a/api/extensions/storage/supabase_storage.py +++ b/api/extensions/storage/supabase_storage.py @@ -36,17 +36,14 @@ class SupabaseStorage(BaseStorage): return content def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - result = self.client.storage.from_(self.bucket_name).download(filename) - byte_stream = io.BytesIO(result) - while chunk := byte_stream.read(4096): # Read in chunks of 4KB - yield chunk - - return generate() + result = self.client.storage.from_(self.bucket_name).download(filename) + byte_stream = io.BytesIO(result) + while chunk := byte_stream.read(4096): # Read in chunks of 4KB + yield chunk def download(self, filename, target_filepath): result = self.client.storage.from_(self.bucket_name).download(filename) - Path(result).write_bytes(result) + Path(target_filepath).write_bytes(result) def exists(self, filename): result = self.client.storage.from_(self.bucket_name).list(filename) diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py index 8fd8e703a1..13a6c9239c 100644 --- a/api/extensions/storage/tencent_cos_storage.py +++ b/api/extensions/storage/tencent_cos_storage.py @@ -29,11 +29,8 @@ class TencentCosStorage(BaseStorage): return data def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response["Body"].get_stream(chunk_size=4096) - - return generate() + response = self.client.get_object(Bucket=self.bucket_name, Key=filename) + yield from response["Body"].get_stream(chunk_size=4096) def download(self, filename, target_filepath): response = self.client.get_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index 389c5630e3..de82be04ea 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -27,12 +27,9 @@ class VolcengineTosStorage(BaseStorage): return data def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - response = self.client.get_object(bucket=self.bucket_name, key=filename) - while chunk := response.read(4096): - yield chunk - - return generate() + response = self.client.get_object(bucket=self.bucket_name, key=filename) + while chunk := response.read(4096): + yield chunk def download(self, filename, target_filepath): self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index fa88e2b4fe..ead7b9a8b3 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -179,27 +179,19 @@ def _build_from_remote_url( if not url: raise ValueError("Invalid file url") + mime_type = mimetypes.guess_type(url)[0] or "" + file_size = -1 + filename = url.split("/")[-1].split("?")[0] or "unknown_file" + resp = ssrf_proxy.head(url, follow_redirects=True) if resp.status_code == httpx.codes.OK: - # Try to extract filename from response headers or URL - content_disposition = resp.headers.get("Content-Disposition") - if content_disposition: + if content_disposition := resp.headers.get("Content-Disposition"): filename = content_disposition.split("filename=")[-1].strip('"') - else: - filename = url.split("/")[-1].split("?")[0] - # Create the File object - file_size = int(resp.headers.get("Content-Length", -1)) - mime_type = str(resp.headers.get("Content-Type", "")) - else: - filename = "" - file_size = -1 - mime_type = "" + file_size = int(resp.headers.get("Content-Length", file_size)) + mime_type = mime_type or str(resp.headers.get("Content-Type", "")) - # If filename is empty, set a default one - if not filename: - filename = "unknown_file" # Determine file extension - extension = "." + filename.split(".")[-1] if "." in filename else ".bin" + extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin" if not mime_type: mime_type, _ = mimetypes.guess_type(url) diff --git a/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py b/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py new file mode 100644 index 0000000000..ca2e410442 --- /dev/null +++ b/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py @@ -0,0 +1,51 @@ +"""add-tidb-auth-binding + +Revision ID: 0251a1c768cc +Revises: 63a83fcf12ba +Create Date: 2024-08-15 09:56:59.012490 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '0251a1c768cc' +down_revision = 'bbadea11becb' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tidb_auth_bindings', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=True), + sa.Column('cluster_id', sa.String(length=255), nullable=False), + sa.Column('cluster_name', sa.String(length=255), nullable=False), + sa.Column('active', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('status', sa.String(length=255), server_default=sa.text("'CREATING'::character varying"), nullable=False), + sa.Column('account', sa.String(length=255), nullable=False), + sa.Column('password', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tidb_auth_bindings_pkey') + ) + with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op: + batch_op.create_index('tidb_auth_bindings_active_idx', ['active'], unique=False) + batch_op.create_index('tidb_auth_bindings_status_idx', ['status'], unique=False) + batch_op.create_index('tidb_auth_bindings_created_at_idx', ['created_at'], unique=False) + batch_op.create_index('tidb_auth_bindings_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op: + batch_op.drop_index('tidb_auth_bindings_tenant_idx') + batch_op.drop_index('tidb_auth_bindings_created_at_idx') + batch_op.drop_index('tidb_auth_bindings_active_idx') + batch_op.drop_index('tidb_auth_bindings_status_idx') + op.drop_table('tidb_auth_bindings') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py b/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py new file mode 100644 index 0000000000..9daf148bc4 --- /dev/null +++ b/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py @@ -0,0 +1,42 @@ +"""add_white_list + +Revision ID: 43fa78bc3b7d +Revises: 0251a1c768cc +Create Date: 2024-10-22 09:59:23.713716 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '43fa78bc3b7d' +down_revision = '0251a1c768cc' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('whitelists', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=True), + sa.Column('category', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='whitelists_pkey') + ) + with op.batch_alter_table('whitelists', schema=None) as batch_op: + batch_op.create_index('whitelists_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('whitelists', schema=None) as batch_op: + batch_op.drop_index('whitelists_tenant_idx') + + op.drop_table('whitelists') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 4e2ccab7e8..a1a626d7e4 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -560,10 +560,28 @@ class DocumentSegment(db.Model): ) def get_sign_content(self): - pattern = r"/files/([a-f0-9\-]+)/file-preview" - text = self.content - matches = re.finditer(pattern, text) signed_urls = [] + text = self.content + + # For data before v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/image-preview" + matches = re.finditer(pattern, text) + for match in matches: + upload_file_id = match.group(1) + nonce = os.urandom(16).hex() + timestamp = str(int(time.time())) + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + signed_url = f"{match.group(0)}?{params}" + signed_urls.append((match.start(), match.end(), signed_url)) + + # For data after v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/file-preview" + matches = re.finditer(pattern, text) for match in matches: upload_file_id = match.group(1) nonce = os.urandom(16).hex() @@ -704,6 +722,38 @@ class DatasetCollectionBinding(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) +class TidbAuthBinding(db.Model): + __tablename__ = "tidb_auth_bindings" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), + db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"), + db.Index("tidb_auth_bindings_active_idx", "active"), + db.Index("tidb_auth_bindings_created_at_idx", "created_at"), + db.Index("tidb_auth_bindings_status_idx", "status"), + ) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=True) + cluster_id = db.Column(db.String(255), nullable=False) + cluster_name = db.Column(db.String(255), nullable=False) + active = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING")) + account = db.Column(db.String(255), nullable=False) + password = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + + +class Whitelist(db.Model): + __tablename__ = "whitelists" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="whitelists_pkey"), + db.Index("whitelists_tenant_idx", "tenant_id"), + ) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=True) + category = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + + class DatasetPermission(db.Model): __tablename__ = "dataset_permissions" __table_args__ = ( diff --git a/api/models/model.py b/api/models/model.py index e289423d16..3bd5886d75 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -977,6 +977,9 @@ class Message(db.Model): config=FileExtraConfig(), ) elif message_file.transfer_method == "tool_file": + if message_file.upload_file_id is None: + assert message_file.url is not None + message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0] mapping = { "id": message_file.id, "type": message_file.type, @@ -1001,6 +1004,7 @@ class Message(db.Model): for (file, message_file) in zip(files, message_files) ] + db.session.commit() return result @property diff --git a/api/poetry.lock b/api/poetry.lock index a0d418ba7b..e1e5a6410b 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -847,13 +847,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.35.46" +version = "1.35.47" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.35.46-py3-none-any.whl", hash = "sha256:8bbc9a55cae65a8db7f2e33ff087f4dbfc13fce868e8e3c5273ce9af367a555a"}, - {file = "botocore-1.35.46.tar.gz", hash = "sha256:8c0ff5fdd611a28f5752189d171c69690dbc484fa06d74376890bb0543ec3dc1"}, + {file = "botocore-1.35.47-py3-none-any.whl", hash = "sha256:05f4493119a96799ff84d43e78691efac3177e1aec8840cca99511de940e342a"}, + {file = "botocore-1.35.47.tar.gz", hash = "sha256:f8f703463d3cd8b6abe2bedc443a7ab29f0e2ff1588a2e83164b108748645547"}, ] [package.dependencies] @@ -1801,6 +1801,46 @@ requests = ">=2.8" six = "*" xmltodict = "*" +[[package]] +name = "couchbase" +version = "4.3.3" +description = "Python Client for Couchbase" +optional = false +python-versions = ">=3.7" +files = [ + {file = "couchbase-4.3.3-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:d8069e4f01332859d56cca597874645c914699162b3979d1b432f0dfc186b124"}, + {file = "couchbase-4.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1caa6cfef49c785b35b1702102f718227f351df87bba2694b9334520c41e9eb5"}, + {file = "couchbase-4.3.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f4a9a65c44935249fa078fb90a3c28ea71da9d2d5889fcd514b12d0538010ae0"}, + {file = "couchbase-4.3.3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4f144b8c482c18283d8e419b844630d41f3249b07d43d40b5e3535444e57d0fb"}, + {file = "couchbase-4.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1c534fba6fdc7cf47eed9dee8a57d1e9eb867bf008574e321fa380a77cebf32f"}, + {file = "couchbase-4.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:b841be06e0e4370b69ebef6bca3409c378186f7d6e964cd645ba18e97216c022"}, + {file = "couchbase-4.3.3-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:eee7a73b3acbdc78ae314fddf7f975b3c9e05df07df255f4dcc878939a2abae0"}, + {file = "couchbase-4.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:53417cafcf90ff4e2fd81ebba2a08b7ad56f17160d1c5019ad3b09c758aeb363"}, + {file = "couchbase-4.3.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0cefd13bea8b0f150f1b9d27fd7614f971f77419b31817781d26ba315ed658bb"}, + {file = "couchbase-4.3.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:78fa1054d7740e2fe38fce0a2aab4e9a2d30263d894e0615ee5df297f02f59a3"}, + {file = "couchbase-4.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:eb093899cfad5a7472258a9b6a57775dbf23a6e0180241507ba89ce3ab241e41"}, + {file = "couchbase-4.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:f7cfbdc699af5715f49365ffbb05a6a7366a534c0d7161edf270ad3e735a6c5d"}, + {file = "couchbase-4.3.3-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:58352cae9b8affdaa2ac012e0a03c8c2632ee6297a878232888b4e0360d0d5df"}, + {file = "couchbase-4.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:728e7e3b5e1682706cb9d63993d289226d02a25089527b8ecb4e3889dabc38cf"}, + {file = "couchbase-4.3.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:73014bf098cf14187a39cc13453e0d859c1d54568df28f69cc308a9a5f24feb2"}, + {file = "couchbase-4.3.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a743375804068ae01b73c916bfca738764c8c12f381bb399ef04e784935856a1"}, + {file = "couchbase-4.3.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:394c122cfe02a76a99e7d5178e64129f6da49843225e78d8629abcab556c24af"}, + {file = "couchbase-4.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:bf85d7a5cda548d9801614651206068b4445fa37972e62b14d7521a958198693"}, + {file = "couchbase-4.3.3-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:92d23c9cedd571631070791f2afee0e3d7d8c9ce1bf2ea6e9a4f2fdbc37a0f1e"}, + {file = "couchbase-4.3.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:38c42eb29a73cce2998ae5df45bd61b16dce9765d3bff968ec5cf6a622faa291"}, + {file = "couchbase-4.3.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:afed137bf0edc642d7b201b6ab7b1e7117bb4c8eac6b2f253cc6e106f334a2a1"}, + {file = "couchbase-4.3.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:954d991377d47883aaf903934c5d0f19577680a2abf80d3ce5bb9b3c80991fc7"}, + {file = "couchbase-4.3.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d5552b9fa684630698dc98d6f3b1082540634c1b7ad5bf53b843b5da57b0169c"}, + {file = "couchbase-4.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:f88f2b7e0c894f7237d9f3fb5c46abc44b8151a97b3ca8e75f57d23ebf59f9da"}, + {file = "couchbase-4.3.3-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:769e1e2367ea1d4de181fcd4b4e353e9abef97d15b581a6c5aea49ece3dc7d59"}, + {file = "couchbase-4.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:47f59a0b35ffce060583fd11f98f049f3b70701cf14aab9ac092594aca486aeb"}, + {file = "couchbase-4.3.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:440bb93d611827ba0ea2403c6f204fe931467a6cb5811f0e03bf1779204ef843"}, + {file = "couchbase-4.3.3-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cdb4dde62e1d41c0b8707121ab68fa78b7a1508541bd48fc850be396f91bc8d9"}, + {file = "couchbase-4.3.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7f8cf45f317b39cc19db5c67b565662f08d6c90305b3aa14e04bc22707258213"}, + {file = "couchbase-4.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:c97d48ad486c8f201b4482d5594258f949369cb44792ed148d5159a3d12ae21b"}, + {file = "couchbase-4.3.3.tar.gz", hash = "sha256:27808500551564b39b46943cf3daab572694889c1eb638425d363edb48b20da7"}, +] + [[package]] name = "coverage" version = "7.2.7" @@ -2342,6 +2382,20 @@ files = [ {file = "et_xmlfile-1.1.0.tar.gz", hash = "sha256:8eb9e2bc2f8c97e37a2dc85a09ecdcdec9d8a396530a6d5a33b30b9a92da0c5c"}, ] +[[package]] +name = "eval-type-backport" +version = "0.2.0" +description = "Like `typing._eval_type`, but lets older Python versions use newer typing features." +optional = false +python-versions = ">=3.8" +files = [ + {file = "eval_type_backport-0.2.0-py3-none-any.whl", hash = "sha256:ac2f73d30d40c5a30a80b8739a789d6bb5e49fdffa66d7912667e2015d9c9933"}, + {file = "eval_type_backport-0.2.0.tar.gz", hash = "sha256:68796cfbc7371ebf923f03bdf7bef415f3ec098aeced24e054b253a0e78f7b37"}, +] + +[package.extras] +tests = ["pytest"] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -4279,6 +4333,17 @@ files = [ [package.dependencies] ply = "*" +[[package]] +name = "jsonpath-python" +version = "1.0.6" +description = "A more powerful JSONPath implementation in modern python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "jsonpath-python-1.0.6.tar.gz", hash = "sha256:dd5be4a72d8a2995c3f583cf82bf3cd1a9544cfdabf2d22595b67aff07349666"}, + {file = "jsonpath_python-1.0.6-py3-none-any.whl", hash = "sha256:1e3b78df579f5efc23565293612decee04214609208a2335884b3ee3f786b575"}, +] + [[package]] name = "jsonschema" version = "4.23.0" @@ -5265,23 +5330,6 @@ files = [ msal = ">=1.29,<2" portalocker = ">=1.4,<3" -[[package]] -name = "msg-parser" -version = "1.2.0" -description = "This module enables reading, parsing and converting Microsoft Outlook MSG E-Mail files." -optional = false -python-versions = ">=3.4" -files = [ - {file = "msg_parser-1.2.0-py2.py3-none-any.whl", hash = "sha256:d47a2f0b2a359cb189fad83cc991b63ea781ecc70d91410324273fbf93e95375"}, - {file = "msg_parser-1.2.0.tar.gz", hash = "sha256:0de858d4fcebb6c8f6f028da83a17a20fe01cdce67c490779cf43b3b0162aa66"}, -] - -[package.dependencies] -olefile = ">=0.46" - -[package.extras] -rtf = ["compressed-rtf (>=1.0.5)"] - [[package]] name = "msrest" version = "0.7.1" @@ -5457,6 +5505,17 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "nest-asyncio" +version = "1.6.0" +description = "Patch asyncio to allow nested event loops" +optional = false +python-versions = ">=3.5" +files = [ + {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, + {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, +] + [[package]] name = "newspaper3k" version = "0.2.8" @@ -5485,13 +5544,13 @@ tldextract = ">=2.0.1" [[package]] name = "nltk" -version = "3.8.1" +version = "3.9.1" description = "Natural Language Toolkit" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "nltk-3.8.1-py3-none-any.whl", hash = "sha256:fd5c9109f976fa86bcadba8f91e47f5e9293bd034474752e92a520f81c93dda5"}, - {file = "nltk-3.8.1.zip", hash = "sha256:1834da3d0682cba4f2cede2f9aad6b0fafb6461ba451db0efb6f9c39798d64d3"}, + {file = "nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1"}, + {file = "nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868"}, ] [package.dependencies] @@ -5780,13 +5839,13 @@ sympy = "*" [[package]] name = "openai" -version = "1.52.1" +version = "1.52.2" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.52.1-py3-none-any.whl", hash = "sha256:f23e83df5ba04ee0e82c8562571e8cb596cd88f9a84ab783e6c6259e5ffbfb4a"}, - {file = "openai-1.52.1.tar.gz", hash = "sha256:383b96c7e937cbec23cad5bf5718085381e4313ca33c5c5896b54f8e1b19d144"}, + {file = "openai-1.52.2-py3-none-any.whl", hash = "sha256:57e9e37bc407f39bb6ec3a27d7e8fb9728b2779936daa1fcf95df17d3edfaccc"}, + {file = "openai-1.52.2.tar.gz", hash = "sha256:87b7d0f69d85f5641678d414b7ee3082363647a5c66a462ed7f3ccb59582da0d"}, ] [package.dependencies] @@ -6638,13 +6697,13 @@ wcwidth = "*" [[package]] name = "proto-plus" -version = "1.24.0" +version = "1.25.0" description = "Beautiful, Pythonic protocol buffers." optional = false python-versions = ">=3.7" files = [ - {file = "proto-plus-1.24.0.tar.gz", hash = "sha256:30b72a5ecafe4406b0d339db35b56c4059064e69227b8c3bda7462397f966445"}, - {file = "proto_plus-1.24.0-py3-none-any.whl", hash = "sha256:402576830425e5f6ce4c2a6702400ac79897dab0b4343821aa5188b0fab81a12"}, + {file = "proto_plus-1.25.0-py3-none-any.whl", hash = "sha256:c91fc4a65074ade8e458e95ef8bac34d4008daa7cce4a12d6707066fca648961"}, + {file = "proto_plus-1.25.0.tar.gz", hash = "sha256:fbb17f57f7bd05a68b7707e745e26528b0b3c34e378db91eef93912c54982d91"}, ] [package.dependencies] @@ -6831,6 +6890,19 @@ files = [ {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, + {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, + {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, + {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, ] [package.dependencies] @@ -7240,6 +7312,27 @@ files = [ [package.extras] diagrams = ["jinja2", "railroad-diagrams"] +[[package]] +name = "pypdf" +version = "5.0.1" +description = "A pure-python PDF library capable of splitting, merging, cropping, and transforming PDF files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pypdf-5.0.1-py3-none-any.whl", hash = "sha256:ff8a32da6c7a63fea9c32fa4dd837cdd0db7966adf6c14f043e3f12592e992db"}, + {file = "pypdf-5.0.1.tar.gz", hash = "sha256:a361c3c372b4a659f9c8dd438d5ce29a753c79c620dc6e1fd66977651f5547ea"}, +] + +[package.dependencies] +typing_extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} + +[package.extras] +crypto = ["PyCryptodome", "cryptography"] +dev = ["black", "flit", "pip-tools", "pre-commit (<2.18.0)", "pytest-cov", "pytest-socket", "pytest-timeout", "pytest-xdist", "wheel"] +docs = ["myst_parser", "sphinx", "sphinx_rtd_theme"] +full = ["Pillow (>=8.0.0)", "PyCryptodome", "cryptography"] +image = ["Pillow (>=8.0.0)"] + [[package]] name = "pypdfium2" version = "4.17.0" @@ -7495,13 +7588,13 @@ files = [ [[package]] name = "python-dateutil" -version = "2.9.0.post0" +version = "2.8.2" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ - {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, - {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, ] [package.dependencies] @@ -7562,19 +7655,36 @@ files = [ ] [[package]] -name = "python-pptx" -version = "0.6.23" -description = "Generate and manipulate Open XML PowerPoint (.pptx) files" +name = "python-oxmsg" +version = "0.0.1" +description = "Extract attachments from Outlook .msg files." optional = false -python-versions = "*" +python-versions = ">=3.9" files = [ - {file = "python-pptx-0.6.23.tar.gz", hash = "sha256:587497ff28e779ab18dbb074f6d4052893c85dedc95ed75df319364f331fedee"}, - {file = "python_pptx-0.6.23-py3-none-any.whl", hash = "sha256:dd0527194627a2b7cc05f3ba23ecaa2d9a0d5ac9b6193a28ed1b7a716f4217d4"}, + {file = "python_oxmsg-0.0.1-py3-none-any.whl", hash = "sha256:8ea7d5dda1bc161a413213da9e18ed152927c1fda2feaf5d1f02192d8ad45eea"}, + {file = "python_oxmsg-0.0.1.tar.gz", hash = "sha256:b65c1f93d688b85a9410afa824192a1ddc39da359b04a0bd2cbd3874e84d4994"}, +] + +[package.dependencies] +click = "*" +olefile = "*" +typing-extensions = ">=4.9.0" + +[[package]] +name = "python-pptx" +version = "1.0.2" +description = "Create, read, and update PowerPoint 2007+ (.pptx) files." +optional = false +python-versions = ">=3.8" +files = [ + {file = "python_pptx-1.0.2-py3-none-any.whl", hash = "sha256:160838e0b8565a8b1f67947675886e9fea18aa5e795db7ae531606d68e785cba"}, + {file = "python_pptx-1.0.2.tar.gz", hash = "sha256:479a8af0eaf0f0d76b6f00b0887732874ad2e3188230315290cd1f9dd9cc7095"}, ] [package.dependencies] lxml = ">=3.1.0" Pillow = ">=3.3.2" +typing-extensions = ">=4.9.0" XlsxWriter = ">=0.5.7" [[package]] @@ -9143,13 +9253,13 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"] [[package]] name = "tencentcloud-sdk-python-common" -version = "3.0.1256" +version = "3.0.1257" description = "Tencent Cloud Common SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-common-3.0.1256.tar.gz", hash = "sha256:83c7b4154f502d67741486cbeae99e734aaf0bfe7eb2cdc9e77916cce9017f17"}, - {file = "tencentcloud_sdk_python_common-3.0.1256-py2.py3-none-any.whl", hash = "sha256:2639f3510743003add35f97c7717680ff357b5324ce3bd67466c19dafe4e7f0b"}, + {file = "tencentcloud-sdk-python-common-3.0.1257.tar.gz", hash = "sha256:e10b155d598a60c43a491be10f40f7dae5774a2187d55f2da83bdb559434f3c4"}, + {file = "tencentcloud_sdk_python_common-3.0.1257-py2.py3-none-any.whl", hash = "sha256:f474a2969f3cbff91f45780f18bfbb90ab53f66c0085c4e9b4e07c2fcf0e71d9"}, ] [package.dependencies] @@ -9157,17 +9267,17 @@ requests = ">=2.16.0" [[package]] name = "tencentcloud-sdk-python-hunyuan" -version = "3.0.1256" +version = "3.0.1257" description = "Tencent Cloud Hunyuan SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-hunyuan-3.0.1256.tar.gz", hash = "sha256:bbc72cd4ef8aacde6ff186521be9ba62eb795eab673fa8e87a4e51e91d33cb95"}, - {file = "tencentcloud_sdk_python_hunyuan-3.0.1256-py2.py3-none-any.whl", hash = "sha256:fbc2e66aac8a9279c560e041a843c2a6884c9145846e6215cde59fa0a127e7e1"}, + {file = "tencentcloud-sdk-python-hunyuan-3.0.1257.tar.gz", hash = "sha256:4d38505089bed70dda1f806f8c4835f8a8c520efa86dcecfef444045c21b695d"}, + {file = "tencentcloud_sdk_python_hunyuan-3.0.1257-py2.py3-none-any.whl", hash = "sha256:c9089d3e49304c9c20e7465c82372b2cd234e67f63efdffb6798a4093b3a97c6"}, ] [package.dependencies] -tencentcloud-sdk-python-common = "3.0.1256" +tencentcloud-sdk-python-common = "3.0.1257" [[package]] name = "threadpoolctl" @@ -9703,13 +9813,13 @@ files = [ [[package]] name = "unstructured" -version = "0.10.30" +version = "0.16.1" description = "A library that prepares raw documents for downstream ML tasks." optional = false -python-versions = ">=3.7.0" +python-versions = "<3.13,>=3.9.0" files = [ - {file = "unstructured-0.10.30-py3-none-any.whl", hash = "sha256:0615f14daa37450e9c0fcf3c3fd178c3a06b6b8d006a36d1a5e54dbe487aa6b6"}, - {file = "unstructured-0.10.30.tar.gz", hash = "sha256:a86c3d15c572a28322d83cb5ecf0ac7a24f1c36864fb7c68df096de8a1acc106"}, + {file = "unstructured-0.16.1-py3-none-any.whl", hash = "sha256:7512281a2917809a563cbb186876b77d5a361e1f3089eca61e9219aecd1218f9"}, + {file = "unstructured-0.16.1.tar.gz", hash = "sha256:03608b5189a004412cd618ce2d083ff926c56dbbca41b41c92e08ffa9e2bac3a"}, ] [package.dependencies] @@ -9719,71 +9829,70 @@ chardet = "*" dataclasses-json = "*" emoji = "*" filetype = "*" +html5lib = "*" langdetect = "*" lxml = "*" markdown = {version = "*", optional = true, markers = "extra == \"md\""} -msg-parser = {version = "*", optional = true, markers = "extra == \"msg\""} nltk = "*" -numpy = "*" +numpy = "<2" +psutil = "*" pypandoc = {version = "*", optional = true, markers = "extra == \"epub\""} -python-docx = {version = ">=1.1.0", optional = true, markers = "extra == \"docx\""} +python-docx = {version = ">=1.1.2", optional = true, markers = "extra == \"docx\""} python-iso639 = "*" python-magic = "*" -python-pptx = {version = "<=0.6.23", optional = true, markers = "extra == \"ppt\" or extra == \"pptx\""} +python-oxmsg = "*" +python-pptx = {version = ">=1.0.1", optional = true, markers = "extra == \"ppt\" or extra == \"pptx\""} rapidfuzz = "*" requests = "*" -tabulate = "*" +tqdm = "*" typing-extensions = "*" +unstructured-client = "*" +wrapt = "*" [package.extras] -airtable = ["pyairtable"] -all-docs = ["markdown", "msg-parser", "networkx", "onnx", "openpyxl", "pandas", "pdf2image", "pdfminer.six", "pypandoc", "python-docx (>=1.1.0)", "python-pptx (<=0.6.23)", "unstructured-inference (==0.7.11)", "unstructured.pytesseract (>=0.3.12)", "xlrd"] -azure = ["adlfs", "fsspec (==2023.9.1)"] -azure-cognitive-search = ["azure-search-documents"] -bedrock = ["boto3", "langchain"] -biomed = ["bs4"] -box = ["boxfs", "fsspec (==2023.9.1)"] -confluence = ["atlassian-python-api"] +all-docs = ["effdet", "google-cloud-vision", "markdown", "networkx", "onnx", "openpyxl", "pandas", "pdf2image", "pdfminer.six", "pi-heif", "pikepdf", "pypandoc", "pypdf", "python-docx (>=1.1.2)", "python-pptx (>=1.0.1)", "unstructured-inference (==0.8.0)", "unstructured.pytesseract (>=0.3.12)", "xlrd"] csv = ["pandas"] -delta-table = ["deltalake", "fsspec (==2023.9.1)"] -discord = ["discord-py"] -doc = ["python-docx (>=1.1.0)"] -docx = ["python-docx (>=1.1.0)"] -dropbox = ["dropboxdrivefs", "fsspec (==2023.9.1)"] -elasticsearch = ["elasticsearch", "jq"] -embed-huggingface = ["huggingface", "langchain", "sentence-transformers"] +doc = ["python-docx (>=1.1.2)"] +docx = ["python-docx (>=1.1.2)"] epub = ["pypandoc"] -gcs = ["bs4", "fsspec (==2023.9.1)", "gcsfs"] -github = ["pygithub (>1.58.0)"] -gitlab = ["python-gitlab"] -google-drive = ["google-api-python-client"] huggingface = ["langdetect", "sacremoses", "sentencepiece", "torch", "transformers"] -image = ["onnx", "pdf2image", "pdfminer.six", "unstructured-inference (==0.7.11)", "unstructured.pytesseract (>=0.3.12)"] -jira = ["atlassian-python-api"] -local-inference = ["markdown", "msg-parser", "networkx", "onnx", "openpyxl", "pandas", "pdf2image", "pdfminer.six", "pypandoc", "python-docx (>=1.1.0)", "python-pptx (<=0.6.23)", "unstructured-inference (==0.7.11)", "unstructured.pytesseract (>=0.3.12)", "xlrd"] +image = ["effdet", "google-cloud-vision", "onnx", "pdf2image", "pdfminer.six", "pi-heif", "pikepdf", "pypdf", "unstructured-inference (==0.8.0)", "unstructured.pytesseract (>=0.3.12)"] +local-inference = ["effdet", "google-cloud-vision", "markdown", "networkx", "onnx", "openpyxl", "pandas", "pdf2image", "pdfminer.six", "pi-heif", "pikepdf", "pypandoc", "pypdf", "python-docx (>=1.1.2)", "python-pptx (>=1.0.1)", "unstructured-inference (==0.8.0)", "unstructured.pytesseract (>=0.3.12)", "xlrd"] md = ["markdown"] -msg = ["msg-parser"] -notion = ["htmlBuilder", "notion-client"] -odt = ["pypandoc", "python-docx (>=1.1.0)"] -onedrive = ["Office365-REST-Python-Client (<2.4.3)", "bs4", "msal"] -openai = ["langchain", "openai", "tiktoken"] +odt = ["pypandoc", "python-docx (>=1.1.2)"] org = ["pypandoc"] -outlook = ["Office365-REST-Python-Client (<2.4.3)", "msal"] -paddleocr = ["unstructured.paddleocr (==2.6.1.3)"] -pdf = ["onnx", "pdf2image", "pdfminer.six", "unstructured-inference (==0.7.11)", "unstructured.pytesseract (>=0.3.12)"] -ppt = ["python-pptx (<=0.6.23)"] -pptx = ["python-pptx (<=0.6.23)"] -reddit = ["praw"] +paddleocr = ["paddlepaddle (==3.0.0b1)", "unstructured.paddleocr (==2.8.1.0)"] +pdf = ["effdet", "google-cloud-vision", "onnx", "pdf2image", "pdfminer.six", "pi-heif", "pikepdf", "pypdf", "unstructured-inference (==0.8.0)", "unstructured.pytesseract (>=0.3.12)"] +ppt = ["python-pptx (>=1.0.1)"] +pptx = ["python-pptx (>=1.0.1)"] rst = ["pypandoc"] rtf = ["pypandoc"] -s3 = ["fsspec (==2023.9.1)", "s3fs"] -salesforce = ["simple-salesforce"] -sharepoint = ["Office365-REST-Python-Client (<2.4.3)", "msal"] -slack = ["slack-sdk"] tsv = ["pandas"] -wikipedia = ["wikipedia"] xlsx = ["networkx", "openpyxl", "pandas", "xlrd"] +[[package]] +name = "unstructured-client" +version = "0.26.1" +description = "Python Client SDK for Unstructured API" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "unstructured_client-0.26.1-py3-none-any.whl", hash = "sha256:b8b839d477122bab3f37242cbe44b39f7eb7b564b07b53500321f953710119b6"}, + {file = "unstructured_client-0.26.1.tar.gz", hash = "sha256:907cceb470529b45b0fddb2d0f1bbf4d6568f347c757ab68639a7bb620ec2484"}, +] + +[package.dependencies] +cryptography = ">=3.1" +eval-type-backport = ">=0.2.0,<0.3.0" +httpx = ">=0.27.0" +jsonpath-python = ">=1.0.6,<2.0.0" +nest-asyncio = ">=1.6.0" +pydantic = ">=2.9.0,<2.10.0" +pypdf = ">=4.0" +python-dateutil = "2.8.2" +requests-toolbelt = ">=1.0.0" +typing-inspect = ">=0.9.0,<0.10.0" + [[package]] name = "upstash-vector" version = "0.6.0" @@ -10643,48 +10752,48 @@ test = ["zope.testrunner"] [[package]] name = "zope-interface" -version = "7.1.0" +version = "7.1.1" description = "Interfaces for Python" optional = false python-versions = ">=3.8" files = [ - {file = "zope.interface-7.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2bd9e9f366a5df08ebbdc159f8224904c1c5ce63893984abb76954e6fbe4381a"}, - {file = "zope.interface-7.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:661d5df403cd3c5b8699ac480fa7f58047a3253b029db690efa0c3cf209993ef"}, - {file = "zope.interface-7.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91b6c30689cfd87c8f264acb2fc16ad6b3c72caba2aec1bf189314cf1a84ca33"}, - {file = "zope.interface-7.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b6a4924f5bad9fe21d99f66a07da60d75696a136162427951ec3cb223a5570d"}, - {file = "zope.interface-7.1.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80a3c00b35f6170be5454b45abe2719ea65919a2f09e8a6e7b1362312a872cd3"}, - {file = "zope.interface-7.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:b936d61dbe29572fd2cfe13e30b925e5383bed1aba867692670f5a2a2eb7b4e9"}, - {file = "zope.interface-7.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0ac20581fc6cd7c754f6dff0ae06fedb060fa0e9ea6309d8be8b2701d9ea51c4"}, - {file = "zope.interface-7.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:848b6fa92d7c8143646e64124ed46818a0049a24ecc517958c520081fd147685"}, - {file = "zope.interface-7.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec1ef1fdb6f014d5886b97e52b16d0f852364f447d2ab0f0c6027765777b6667"}, - {file = "zope.interface-7.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bcff5c09d0215f42ba64b49205a278e44413d9bf9fa688fd9e42bfe472b5f4f"}, - {file = "zope.interface-7.1.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07add15de0cc7e69917f7d286b64d54125c950aeb43efed7a5ea7172f000fbc1"}, - {file = "zope.interface-7.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:9940d5bc441f887c5f375ec62bcf7e7e495a2d5b1da97de1184a88fb567f06af"}, - {file = "zope.interface-7.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f245d039f72e6f802902375755846f5de1ee1e14c3e8736c078565599bcab621"}, - {file = "zope.interface-7.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6159e767d224d8f18deff634a1d3722e68d27488c357f62ebeb5f3e2f5288b1f"}, - {file = "zope.interface-7.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e956b1fd7f3448dd5e00f273072e73e50dfafcb35e4227e6d5af208075593c9"}, - {file = "zope.interface-7.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ff115ef91c0eeac69cd92daeba36a9d8e14daee445b504eeea2b1c0b55821984"}, - {file = "zope.interface-7.1.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bec001798ab62c3fc5447162bf48496ae9fba02edc295a9e10a0b0c639a6452e"}, - {file = "zope.interface-7.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:124149e2d42067b9c6597f4dafdc7a0983d0163868f897b7bb5dc850b14f9a87"}, - {file = "zope.interface-7.1.0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:9733a9a0f94ef53d7aa64661811b20875b5bc6039034c6e42fb9732170130573"}, - {file = "zope.interface-7.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5fcf379b875c610b5a41bc8a891841533f98de0520287d7f85e25386cd10d3e9"}, - {file = "zope.interface-7.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0a45b5af9f72c805ee668d1479480ca85169312211bed6ed18c343e39307d5f"}, - {file = "zope.interface-7.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4af4a12b459a273b0b34679a5c3dc5e34c1847c3dd14a628aa0668e19e638ea2"}, - {file = "zope.interface-7.1.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a735f82d2e3ed47ca01a20dfc4c779b966b16352650a8036ab3955aad151ed8a"}, - {file = "zope.interface-7.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:5501e772aff595e3c54266bc1bfc5858e8f38974ce413a8f1044aae0f32a83a3"}, - {file = "zope.interface-7.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ec59fe53db7d32abb96c6d4efeed84aab4a7c38c62d7a901a9b20c09dd936e7a"}, - {file = "zope.interface-7.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e53c291debef523b09e1fe3dffe5f35dde164f1c603d77f770b88a1da34b7ed6"}, - {file = "zope.interface-7.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:711eebc77f2092c6a8b304bad0b81a6ce3cf5490b25574e7309fbc07d881e3af"}, - {file = "zope.interface-7.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a00ead2e24c76436e1b457a5132d87f83858330f6c923640b7ef82d668525d1"}, - {file = "zope.interface-7.1.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e28ea0bc4b084fc93a483877653a033062435317082cdc6388dec3438309faf"}, - {file = "zope.interface-7.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:27cfb5205d68b12682b6e55ab8424662d96e8ead19550aad0796b08dd2c9a45e"}, - {file = "zope.interface-7.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9e3e48f3dea21c147e1b10c132016cb79af1159facca9736d231694ef5a740a8"}, - {file = "zope.interface-7.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a99240b1d02dc469f6afbe7da1bf617645e60290c272968f4e53feec18d7dce8"}, - {file = "zope.interface-7.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc8a318162123eddbdf22fcc7b751288ce52e4ad096d3766ff1799244352449d"}, - {file = "zope.interface-7.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b7b25db127db3e6b597c5f74af60309c4ad65acd826f89609662f0dc33a54728"}, - {file = "zope.interface-7.1.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a29ac607e970b5576547f0e3589ec156e04de17af42839eedcf478450687317"}, - {file = "zope.interface-7.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:a14c9decf0eb61e0892631271d500c1e306c7b6901c998c7035e194d9150fdd1"}, - {file = "zope_interface-7.1.0.tar.gz", hash = "sha256:3f005869a1a05e368965adb2075f97f8ee9a26c61898a9e52a9764d93774f237"}, + {file = "zope.interface-7.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6650bd56ef350d37c8baccfd3ee8a0483ed6f8666e641e4b9ae1a1827b79f9e5"}, + {file = "zope.interface-7.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:84e87eba6b77a3af187bae82d8de1a7c208c2a04ec9f6bd444fd091b811ad92e"}, + {file = "zope.interface-7.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c4e1b4c06d9abd1037c088dae1566c85f344a3e6ae4350744c3f7f7259d9c67"}, + {file = "zope.interface-7.1.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7cd5e3d910ac87652a09f6e5db8e41bc3b49cf08ddd2d73d30afc644801492cd"}, + {file = "zope.interface-7.1.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca95594d936ee349620900be5b46c0122a1ff6ce42d7d5cb2cf09dc84071ef16"}, + {file = "zope.interface-7.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:ad339509dcfbbc99bf8e147db6686249c4032f26586699ec4c82f6e5909c9fe2"}, + {file = "zope.interface-7.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e59f175e868f856a77c0a77ba001385c377df2104fdbda6b9f99456a01e102a"}, + {file = "zope.interface-7.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0de23bcb93401994ea00bc5c677ef06d420340ac0a4e9c10d80e047b9ce5af3f"}, + {file = "zope.interface-7.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cdb7e7e5524b76d3ec037c1d81a9e2c7457b240fd4cb0a2476b65c3a5a6c81f"}, + {file = "zope.interface-7.1.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3603ef82a9920bd0bfb505423cb7e937498ad971ad5a6141841e8f76d2fd5446"}, + {file = "zope.interface-7.1.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1d52d052355e0c5c89e0630dd2ff7c0b823fd5f56286a663e92444761b35e25"}, + {file = "zope.interface-7.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:179ad46ece518c9084cb272e4a69d266b659f7f8f48e51706746c2d8a426433e"}, + {file = "zope.interface-7.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e6503534b52bb1720ace9366ee30838a58a3413d3e197512f3338c8f34b5d89d"}, + {file = "zope.interface-7.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f85b290e5b8b11814efb0d004d8ce6c9a483c35c462e8d9bf84abb93e79fa770"}, + {file = "zope.interface-7.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d029fac6a80edae80f79c37e5e3abfa92968fe921886139b3ee470a1b177321a"}, + {file = "zope.interface-7.1.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5836b8fb044c6e75ba34dfaabc602493019eadfa0faf6ff25f4c4c356a71a853"}, + {file = "zope.interface-7.1.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7395f13533318f150ee72adb55b29284b16e73b6d5f02ab21f173b3e83f242b8"}, + {file = "zope.interface-7.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:1d0e23c6b746eb8ce04573cc47bcac60961ac138885d207bd6f57e27a1431ae8"}, + {file = "zope.interface-7.1.1-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:9fad9bd5502221ab179f13ea251cb30eef7cf65023156967f86673aff54b53a0"}, + {file = "zope.interface-7.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:55c373becbd36a44d0c9be1d5271422fdaa8562d158fb44b4192297b3c67096c"}, + {file = "zope.interface-7.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed1df8cc01dd1e3970666a7370b8bfc7457371c58ba88c57bd5bca17ab198053"}, + {file = "zope.interface-7.1.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99c14f0727c978639139e6cad7a60e82b7720922678d75aacb90cf4ef74a068c"}, + {file = "zope.interface-7.1.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b1eed7670d564f1025d7cda89f99f216c30210e42e95de466135be0b4a499d9"}, + {file = "zope.interface-7.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:3defc925c4b22ac1272d544a49c6ba04c3eefcce3200319ee1be03d9270306dd"}, + {file = "zope.interface-7.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8d0fe45be57b5219aa4b96e846631c04615d5ef068146de5a02ccd15c185321f"}, + {file = "zope.interface-7.1.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bcbeb44fc16e0078b3b68a95e43f821ae34dcbf976dde6985141838a5f23dd3d"}, + {file = "zope.interface-7.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c8e7b05dc6315a193cceaec071cc3cf1c180cea28808ccded0b1283f1c38ba73"}, + {file = "zope.interface-7.1.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d553e02b68c0ea5a226855f02edbc9eefd99f6a8886fa9f9bdf999d77f46585"}, + {file = "zope.interface-7.1.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81744a7e61b598ebcf4722ac56a7a4f50502432b5b4dc7eb29075a89cf82d029"}, + {file = "zope.interface-7.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:7720322763aceb5e0a7cadcc38c67b839efe599f0887cbf6c003c55b1458c501"}, + {file = "zope.interface-7.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1a2ed0852c25950cf430067f058f8d98df6288502ac313861d9803fe7691a9b3"}, + {file = "zope.interface-7.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9595e478047ce752b35cfa221d7601a5283ccdaab40422e0dc1d4a334c70f580"}, + {file = "zope.interface-7.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2317e1d4dba68203a5227ea3057f9078ec9376275f9700086b8f0ffc0b358e1b"}, + {file = "zope.interface-7.1.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6821ef9870f32154da873fcde439274f99814ea452dd16b99fa0b66345c4b6b"}, + {file = "zope.interface-7.1.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:190eeec67e023d5aac54d183fa145db0b898664234234ac54643a441da434616"}, + {file = "zope.interface-7.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:d17e7fc814eaab93409b80819fd6d30342844345c27f3bc3c4b43c2425a8d267"}, + {file = "zope.interface-7.1.1.tar.gz", hash = "sha256:4284d664ef0ff7b709836d4de7b13d80873dc5faeffc073abdb280058bfac5e3"}, ] [package.dependencies] @@ -10810,4 +10919,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "32fd52006f75e42fbc8f787e559a72f4e033383c73225231e4ecadabfec926f7" +content-hash = "52552faf5f4823056eb48afe05349ab2f0e9a5bc42105211ccbbb54b59e27b59" diff --git a/api/pyproject.toml b/api/pyproject.toml index e9529a192e..a3313f0ff5 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -172,11 +172,12 @@ sagemaker = "2.231.0" scikit-learn = "~1.5.1" sentry-sdk = { version = "~1.44.1", extras = ["flask"] } sqlalchemy = "~2.0.29" +starlette = "0.41.0" tencentcloud-sdk-python-hunyuan = "~3.0.1158" tiktoken = "~0.8.0" tokenizers = "~0.15.0" transformers = "~4.35.0" -unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] } +unstructured = { version = "~0.16.1", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] } validators = "0.21.0" volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"} websocket-client = "~1.7.0" @@ -206,7 +207,7 @@ duckduckgo-search = "~6.3.0" jsonpath-ng = "1.6.1" matplotlib = "~3.8.2" newspaper3k = "0.2.8" -nltk = "3.8.1" +nltk = "3.9.1" numexpr = "~2.9.0" pydub = "~0.25.1" qrcode = "~7.4.2" @@ -238,6 +239,7 @@ alibabacloud_gpdb20160503 = "~3.8.0" alibabacloud_tea_openapi = "~0.3.9" chromadb = "0.5.1" clickhouse-connect = "~0.7.16" +couchbase = "~4.3.0" elasticsearch = "8.14.0" opensearch-py = "2.4.0" oracledb = "~2.2.1" diff --git a/api/pytest.ini b/api/pytest.ini index dcca08e2e5..a23a4b3f3d 100644 --- a/api/pytest.ini +++ b/api/pytest.ini @@ -27,3 +27,4 @@ env = XINFERENCE_GENERATION_MODEL_UID = generate XINFERENCE_RERANK_MODEL_UID = rerank XINFERENCE_SERVER_URL = http://a.abc.com:11451 + GITEE_AI_API_KEY = aaaaaaaaaaaaaaaaaaaa diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py new file mode 100644 index 0000000000..42d6c04beb --- /dev/null +++ b/api/schedule/create_tidb_serverless_task.py @@ -0,0 +1,56 @@ +import time + +import click + +import app +from configs import dify_config +from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService +from extensions.ext_database import db +from models.dataset import TidbAuthBinding + + +@app.celery.task(queue="dataset") +def create_tidb_serverless_task(): + click.echo(click.style("Start create tidb serverless task.", fg="green")) + tidb_serverless_number = dify_config.TIDB_SERVERLESS_NUMBER + start_at = time.perf_counter() + while True: + try: + # check the number of idle tidb serverless + idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count() + if idle_tidb_serverless_number >= tidb_serverless_number: + break + # create tidb serverless + iterations_per_thread = 20 + create_clusters(iterations_per_thread) + + except Exception as e: + click.echo(click.style(f"Error: {e}", fg="red")) + break + + end_at = time.perf_counter() + click.echo(click.style("Create tidb serverless task success latency: {}".format(end_at - start_at), fg="green")) + + +def create_clusters(batch_size): + try: + new_clusters = TidbService.batch_create_tidb_serverless_cluster( + batch_size, + dify_config.TIDB_PROJECT_ID, + dify_config.TIDB_API_URL, + dify_config.TIDB_IAM_API_URL, + dify_config.TIDB_PUBLIC_KEY, + dify_config.TIDB_PRIVATE_KEY, + dify_config.TIDB_REGION, + ) + for new_cluster in new_clusters: + tidb_auth_binding = TidbAuthBinding( + cluster_id=new_cluster["cluster_id"], + cluster_name=new_cluster["cluster_name"], + account=new_cluster["account"], + password=new_cluster["password"], + ) + db.session.add(tidb_auth_binding) + db.session.commit() + except Exception as e: + click.echo(click.style(f"Error: {e}", fg="red")) diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py new file mode 100644 index 0000000000..07eca3173b --- /dev/null +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -0,0 +1,51 @@ +import time + +import click + +import app +from configs import dify_config +from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService +from models.dataset import TidbAuthBinding + + +@app.celery.task(queue="dataset") +def update_tidb_serverless_status_task(): + click.echo(click.style("Update tidb serverless status task.", fg="green")) + start_at = time.perf_counter() + while True: + try: + # check the number of idle tidb serverless + tidb_serverless_list = TidbAuthBinding.query.filter( + TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING" + ).all() + if len(tidb_serverless_list) == 0: + break + # update tidb serverless status + iterations_per_thread = 20 + update_clusters(tidb_serverless_list) + + except Exception as e: + click.echo(click.style(f"Error: {e}", fg="red")) + break + + end_at = time.perf_counter() + click.echo( + click.style("Update tidb serverless status task success latency: {}".format(end_at - start_at), fg="green") + ) + + +def update_clusters(tidb_serverless_list: list[TidbAuthBinding]): + try: + # batch 20 + for i in range(0, len(tidb_serverless_list), 20): + items = tidb_serverless_list[i : i + 20] + TidbService.batch_update_tidb_serverless_cluster_status( + items, + dify_config.TIDB_PROJECT_ID, + dify_config.TIDB_API_URL, + dify_config.TIDB_IAM_API_URL, + dify_config.TIDB_PUBLIC_KEY, + dify_config.TIDB_PRIVATE_KEY, + ) + except Exception as e: + click.echo(click.style(f"Error: {e}", fg="red")) diff --git a/api/services/account_service.py b/api/services/account_service.py index 412685147c..dceca06185 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -486,9 +486,13 @@ def _get_login_cache_key(*, account_id: str, token: str): class TenantService: @staticmethod - def create_tenant(name: str, is_setup: Optional[bool] = False) -> Tenant: + def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant: """Create tenant""" - if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup: + if ( + not FeatureService.get_system_features().is_allow_create_workspace + and not is_setup + and not is_from_dashboard + ): from controllers.console.error import NotAllowedCreateWorkspace raise NotAllowedCreateWorkspace() @@ -505,9 +509,7 @@ class TenantService: def create_owner_tenant_if_not_exist( account: Account, name: Optional[str] = None, is_setup: Optional[bool] = False ): - """Create owner tenant if not exist""" - if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup: - raise WorkSpaceNotAllowedCreateError() + """Check if user have a workspace or not""" available_ta = ( TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() ) @@ -515,6 +517,10 @@ class TenantService: if available_ta: return + """Create owner tenant if not exist""" + if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup: + raise WorkSpaceNotAllowedCreateError() + if name: tenant = TenantService.create_tenant(name=name, is_setup=is_setup) else: diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 3cc6c51c2d..915d37ec03 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -132,14 +132,14 @@ class AppAnnotationService: MessageAnnotation.content.ilike("%{}%".format(keyword)), ) ) - .order_by(MessageAnnotation.created_at.desc()) + .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) ) else: annotations = ( db.session.query(MessageAnnotation) .filter(MessageAnnotation.app_id == app_id) - .order_by(MessageAnnotation.created_at.desc()) + .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) ) return annotations.items, annotations.total diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py new file mode 100644 index 0000000000..de898a1f94 --- /dev/null +++ b/api/services/auth/jina.py @@ -0,0 +1,44 @@ +import json + +import requests + +from services.auth.api_key_auth_base import ApiKeyAuthBase + + +class JinaAuth(ApiKeyAuthBase): + def __init__(self, credentials: dict): + super().__init__(credentials) + auth_type = credentials.get("auth_type") + if auth_type != "bearer": + raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") + self.api_key = credentials.get("config").get("api_key", None) + + if not self.api_key: + raise ValueError("No API key provided") + + def validate_credentials(self): + headers = self._prepare_headers() + options = { + "url": "https://example.com", + } + response = self._post_request("https://r.jina.ai", options, headers) + if response.status_code == 200: + return True + else: + self._handle_error(response) + + def _prepare_headers(self): + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + + def _post_request(self, url, data, headers): + return requests.post(url, headers=headers, json=data) + + def _handle_error(self, response): + if response.status_code in {402, 409, 500}: + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + else: + if response.text: + error_message = json.loads(response.text).get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}") diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ede8764086..414ef0224a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -140,6 +140,7 @@ class DatasetService: def create_empty_dataset( tenant_id: str, name: str, + description: Optional[str], indexing_technique: Optional[str], account: Account, permission: Optional[str] = None, @@ -158,6 +159,7 @@ class DatasetService: ) dataset = Dataset(name=name, indexing_technique=indexing_technique) # dataset = Dataset(name=name, provider=provider, config=config) + dataset.description = description dataset.created_by = account.id dataset.updated_by = account.id dataset.tenant_id = tenant_id @@ -758,166 +760,168 @@ class DocumentService: ) db.session.add(dataset_process_rule) db.session.commit() - position = DocumentService.get_documents_position(dataset.id) - document_ids = [] - duplicate_document_ids = [] - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] - for file_id in upload_file_list: - file = ( - db.session.query(UploadFile) - .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) - .first() - ) - - # raise error if file not found - if not file: - raise FileNotExistsError() - - file_name = file.name - data_source_info = { - "upload_file_id": file_id, - } - # check duplicate - if document_data.get("duplicate", False): - document = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="upload_file", - enabled=True, - name=file_name, - ).first() - if document: - document.dataset_process_rule_id = dataset_process_rule.id - document.updated_at = datetime.datetime.utcnow() - document.created_from = created_from - document.doc_form = document_data["doc_form"] - document.doc_language = document_data["doc_language"] - document.data_source_info = json.dumps(data_source_info) - document.batch = batch - document.indexing_status = "waiting" - db.session.add(document) - documents.append(document) - duplicate_document_ids.append(document.id) - continue - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], - data_source_info, - created_from, - position, - account, - file_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] - exist_page_ids = [] - exist_document = {} - documents = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="notion_import", - enabled=True, - ).all() - if documents: - for document in documents: - data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info["notion_page_id"]) - exist_document[data_source_info["notion_page_id"]] = document.id - for notion_info in notion_info_list: - workspace_id = notion_info["workspace_id"] - data_source_binding = DataSourceOauthBinding.query.filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) + with redis_client.lock(lock_name, timeout=600): + position = DocumentService.get_documents_position(dataset.id) + document_ids = [] + duplicate_document_ids = [] + if document_data["data_source"]["type"] == "upload_file": + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + for file_id in upload_file_list: + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() ) - ).first() - if not data_source_binding: - raise ValueError("Data source binding not found.") - for page in notion_info["pages"]: - if page["page_id"] not in exist_page_ids: - data_source_info = { - "notion_workspace_id": workspace_id, - "notion_page_id": page["page_id"], - "notion_page_icon": page["page_icon"], - "type": page["type"], - } - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], - data_source_info, - created_from, - position, - account, - page["page_name"], - batch, + + # raise error if file not found + if not file: + raise FileNotExistsError() + + file_name = file.name + data_source_info = { + "upload_file_id": file_id, + } + # check duplicate + if document_data.get("duplicate", False): + document = Document.query.filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="upload_file", + enabled=True, + name=file_name, + ).first() + if document: + document.dataset_process_rule_id = dataset_process_rule.id + document.updated_at = datetime.datetime.utcnow() + document.created_from = created_from + document.doc_form = document_data["doc_form"] + document.doc_language = document_data["doc_language"] + document.data_source_info = json.dumps(data_source_info) + document.batch = batch + document.indexing_status = "waiting" + db.session.add(document) + documents.append(document) + duplicate_document_ids.append(document.id) + continue + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + document_data["data_source"]["type"], + document_data["doc_form"], + document_data["doc_language"], + data_source_info, + created_from, + position, + account, + file_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + elif document_data["data_source"]["type"] == "notion_import": + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] + exist_page_ids = [] + exist_document = {} + documents = Document.query.filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="notion_import", + enabled=True, + ).all() + if documents: + for document in documents: + data_source_info = json.loads(document.data_source_info) + exist_page_ids.append(data_source_info["notion_page_id"]) + exist_document[data_source_info["notion_page_id"]] = document.id + for notion_info in notion_info_list: + workspace_id = notion_info["workspace_id"] + data_source_binding = DataSourceOauthBinding.query.filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 + ).first() + if not data_source_binding: + raise ValueError("Data source binding not found.") + for page in notion_info["pages"]: + if page["page_id"] not in exist_page_ids: + data_source_info = { + "notion_workspace_id": workspace_id, + "notion_page_id": page["page_id"], + "notion_page_icon": page["page_icon"], + "type": page["type"], + } + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + document_data["data_source"]["type"], + document_data["doc_form"], + document_data["doc_language"], + data_source_info, + created_from, + position, + account, + page["page_name"], + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + else: + exist_document.pop(page["page_id"]) + # delete not selected documents + if len(exist_document) > 0: + clean_notion_document_task.delay(list(exist_document.values()), dataset.id) + elif document_data["data_source"]["type"] == "website_crawl": + website_info = document_data["data_source"]["info_list"]["website_info_list"] + urls = website_info["urls"] + for url in urls: + data_source_info = { + "url": url, + "provider": website_info["provider"], + "job_id": website_info["job_id"], + "only_main_content": website_info.get("only_main_content", False), + "mode": "crawl", + } + if len(url) > 255: + document_name = url[:200] + "..." else: - exist_document.pop(page["page_id"]) - # delete not selected documents - if len(exist_document) > 0: - clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - urls = website_info["urls"] - for url in urls: - data_source_info = { - "url": url, - "provider": website_info["provider"], - "job_id": website_info["job_id"], - "only_main_content": website_info.get("only_main_content", False), - "mode": "crawl", - } - if len(url) > 255: - document_name = url[:200] + "..." - else: - document_name = url - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], - data_source_info, - created_from, - position, - account, - document_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - db.session.commit() + document_name = url + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + document_data["data_source"]["type"], + document_data["doc_form"], + document_data["doc_language"], + data_source_info, + created_from, + position, + account, + document_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + db.session.commit() - # trigger async task - if document_ids: - document_indexing_task.delay(dataset.id, document_ids) - if duplicate_document_ids: - duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + # trigger async task + if document_ids: + document_indexing_task.delay(dataset.id, document_ids) + if duplicate_document_ids: + duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) - return documents, batch + return documents, batch @staticmethod def check_documents_upload_quota(count: int, features: FeatureModel): diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 4efdf8d7db..b49738c61c 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -6,6 +6,8 @@ from typing import Any, Optional, Union import httpx import validators +from constants import HIDDEN_VALUE + # from tasks.external_document_indexing_task import external_document_indexing_task from core.helper import ssrf_proxy from extensions.ext_database import db @@ -68,7 +70,7 @@ class ExternalDatasetService: endpoint = f"{settings['endpoint']}/retrieval" api_key = settings["api_key"] - if not validators.url(endpoint): + if not validators.url(endpoint, simple_host=True): raise ValueError(f"invalid endpoint: {endpoint}") try: response = httpx.post(endpoint, headers={"Authorization": f"Bearer {api_key}"}) @@ -92,6 +94,8 @@ class ExternalDatasetService: ).first() if external_knowledge_api is None: raise ValueError("api template not found") + if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: + args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key") external_knowledge_api.name = args.get("name") external_knowledge_api.description = args.get("description", "") diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 2d52399d29..6791cd891b 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -83,3 +83,6 @@ VOLC_EMBEDDING_ENDPOINT_ID= # 360 AI Credentials ZHINAO_API_KEY= + +# Gitee AI Credentials +GITEE_AI_API_KEY= diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/__init__.py b/api/tests/integration_tests/model_runtime/gitee_ai/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_llm.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_llm.py new file mode 100644 index 0000000000..753c52ce31 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_llm.py @@ -0,0 +1,132 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.llm.llm import GiteeAILargeLanguageModel + + +def test_predefined_models(): + model = GiteeAILargeLanguageModel() + model_schemas = model.predefined_models() + + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +def test_validate_credentials_for_chat_model(): + model = GiteeAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + # model name to gpt-3.5-turbo because of mocking + model.validate_credentials(model="gpt-3.5-turbo", credentials={"api_key": "invalid_key"}) + + model.validate_credentials( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + ) + + +def test_invoke_chat_model(): + model = GiteeAILargeLanguageModel() + + result = model.invoke( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, + "stream": False, + }, + stop=["How"], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +def test_invoke_stream_chat_model(): + model = GiteeAILargeLanguageModel() + + result = model.invoke( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100, "stream": False}, + stream=True, + user="foo", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + + +def test_get_num_tokens(): + model = GiteeAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 10 + + num_tokens = model.get_num_tokens( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + ], + ) + + assert num_tokens == 77 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_provider.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_provider.py new file mode 100644 index 0000000000..f12ed54a45 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_provider.py @@ -0,0 +1,15 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.gitee_ai import GiteeAIProvider + + +def test_validate_provider_credentials(): + provider = GiteeAIProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={"api_key": "invalid_key"}) + + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_rerank.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_rerank.py new file mode 100644 index 0000000000..0e5914a61f --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_rerank.py @@ -0,0 +1,47 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.rerank.rerank import GiteeAIRerankModel + + +def test_validate_credentials(): + model = GiteeAIRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="bge-reranker-v2-m3", + credentials={"api_key": "invalid_key"}, + ) + + model.validate_credentials( + model="bge-reranker-v2-m3", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + ) + + +def test_invoke_model(): + model = GiteeAIRerankModel() + result = model.invoke( + model="bge-reranker-v2-m3", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + top_n=1, + score_threshold=0.01, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].score >= 0.01 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_speech2text.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_speech2text.py new file mode 100644 index 0000000000..4a01453fdd --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_speech2text.py @@ -0,0 +1,45 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.speech2text.speech2text import GiteeAISpeech2TextModel + + +def test_validate_credentials(): + model = GiteeAISpeech2TextModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="whisper-base", + credentials={"api_key": "invalid_key"}, + ) + + model.validate_credentials( + model="whisper-base", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + ) + + +def test_invoke_model(): + model = GiteeAISpeech2TextModel() + + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Get assets directory + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") + + # Construct the path to the audio file + audio_file_path = os.path.join(assets_dir, "audio.mp3") + + # Open the file and get the file object + with open(audio_file_path, "rb") as audio_file: + file = audio_file + + result = model.invoke( + model="whisper-base", credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, file=file + ) + + assert isinstance(result, str) + assert result == "1 2 3 4 5 6 7 8 9 10" diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_text_embedding.py new file mode 100644 index 0000000000..34648f0bc8 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_text_embedding.py @@ -0,0 +1,46 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.text_embedding.text_embedding import GiteeAIEmbeddingModel + + +def test_validate_credentials(): + model = GiteeAIEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="bge-large-zh-v1.5", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="bge-large-zh-v1.5", credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}) + + +def test_invoke_model(): + model = GiteeAIEmbeddingModel() + + result = model.invoke( + model="bge-large-zh-v1.5", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + texts=["hello", "world"], + user="user", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + + +def test_get_num_tokens(): + model = GiteeAIEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="bge-large-zh-v1.5", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_tts.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_tts.py new file mode 100644 index 0000000000..9f18161a7b --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_tts.py @@ -0,0 +1,23 @@ +import os + +from core.model_runtime.model_providers.gitee_ai.tts.tts import GiteeAIText2SpeechModel + + +def test_invoke_model(): + model = GiteeAIText2SpeechModel() + + result = model.invoke( + model="speecht5_tts", + tenant_id="test", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + content_text="Hello, world!", + voice="", + ) + + content = b"" + for chunk in result: + content += chunk + + assert content != b"" diff --git a/api/tests/integration_tests/vdb/couchbase/__init__.py b/api/tests/integration_tests/vdb/couchbase/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/vdb/couchbase/test_couchbase.py b/api/tests/integration_tests/vdb/couchbase/test_couchbase.py new file mode 100644 index 0000000000..d76c34ba0e --- /dev/null +++ b/api/tests/integration_tests/vdb/couchbase/test_couchbase.py @@ -0,0 +1,50 @@ +import subprocess +import time + +from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + get_example_text, + setup_mock_redis, +) + + +def wait_for_healthy_container(service_name="couchbase-server", timeout=300): + start_time = time.time() + while time.time() - start_time < timeout: + result = subprocess.run( + ["docker", "inspect", "--format", "{{.State.Health.Status}}", service_name], capture_output=True, text=True + ) + if result.stdout.strip() == "healthy": + print(f"{service_name} is healthy!") + return True + else: + print(f"Waiting for {service_name} to be healthy...") + time.sleep(10) + raise TimeoutError(f"{service_name} did not become healthy in time") + + +class CouchbaseTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = CouchbaseVector( + collection_name=self.collection_name, + config=CouchbaseConfig( + connection_string="couchbase://127.0.0.1", + user="Administrator", + password="password", + bucket_name="Embeddings", + scope_name="_default", + ), + ) + + def search_by_vector(self): + # brief sleep to ensure document is indexed + time.sleep(5) + hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) + assert len(hits_by_vector) == 1 + + +def test_couchbase(setup_mock_redis): + wait_for_healthy_container("couchbase-server", timeout=60) + CouchbaseTest().run_all_tests() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 7471e13e1e..4f1f8f05c8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -63,17 +63,24 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s @pytest.mark.parametrize( - ("mime_type", "file_content", "expected_text", "transfer_method"), + ("mime_type", "file_content", "expected_text", "transfer_method", "extension"), [ - ("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE), - ("application/pdf", b"%PDF-1.5\n%Test PDF content", ["Mocked PDF content"], FileTransferMethod.LOCAL_FILE), + ("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"), + ( + "application/pdf", + b"%PDF-1.5\n%Test PDF content", + ["Mocked PDF content"], + FileTransferMethod.LOCAL_FILE, + ".pdf", + ), ( "application/vnd.openxmlformats-officedocument.wordprocessingml.document", b"PK\x03\x04", ["Mocked DOCX content"], - FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + "", ), - ("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL), + ("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None), ], ) def test_run_extract_text( @@ -83,6 +90,7 @@ def test_run_extract_text( file_content, expected_text, transfer_method, + extension, monkeypatch, ): document_extractor_node.graph_runtime_state = mock_graph_runtime_state @@ -92,6 +100,7 @@ def test_run_extract_text( mock_file.transfer_method = transfer_method mock_file.related_id = "test_file_id" if transfer_method == FileTransferMethod.LOCAL_FILE else None mock_file.remote_url = "https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None + mock_file.extension = extension mock_array_file_segment = Mock(spec=ArrayFileSegment) mock_array_file_segment.value = [mock_file] @@ -116,7 +125,7 @@ def test_run_extract_text( result = document_extractor_node._run() assert isinstance(result, NodeRunResult) - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error assert result.outputs is not None assert result.outputs["text"] == expected_text diff --git a/api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py index 2a5fda48b1..720037d05f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py @@ -192,7 +192,7 @@ def test_http_request_node_form_with_file(monkeypatch): def attr_checker(*args, **kwargs): assert kwargs["data"] == {"name": "test"} - assert kwargs["files"] == {"file": b"test"} + assert kwargs["files"] == {"file": (None, b"test", "application/octet-stream")} return httpx.Response(200, content=b"") monkeypatch.setattr( diff --git a/api/tests/unit_tests/oss/__mock/base.py b/api/tests/unit_tests/oss/__mock/base.py new file mode 100644 index 0000000000..a1eaaab9c3 --- /dev/null +++ b/api/tests/unit_tests/oss/__mock/base.py @@ -0,0 +1,58 @@ +from collections.abc import Generator + +import pytest + +from extensions.storage.base_storage import BaseStorage + + +def get_example_folder() -> str: + return "/dify" + + +def get_example_bucket() -> str: + return "dify" + + +def get_example_filename() -> str: + return "test.txt" + + +def get_example_data() -> bytes: + return b"test" + + +def get_example_filepath() -> str: + return "/test" + + +class BaseStorageTest: + @pytest.fixture(autouse=True) + def setup_method(self): + """Should be implemented in child classes to setup specific storage.""" + self.storage = BaseStorage() + + def test_save(self): + """Test saving data.""" + self.storage.save(get_example_filename(), get_example_data()) + + def test_load_once(self): + """Test loading data once.""" + assert self.storage.load_once(get_example_filename()) == get_example_data() + + def test_load_stream(self): + """Test loading data as a stream.""" + generator = self.storage.load_stream(get_example_filename()) + assert isinstance(generator, Generator) + assert next(generator) == get_example_data() + + def test_download(self): + """Test downloading data.""" + self.storage.download(get_example_filename(), get_example_filepath()) + + def test_exists(self): + """Test checking if a file exists.""" + assert self.storage.exists(get_example_filename()) + + def test_delete(self): + """Test deleting a file.""" + self.storage.delete(get_example_filename()) diff --git a/api/tests/unit_tests/oss/__mock/local.py b/api/tests/unit_tests/oss/__mock/local.py new file mode 100644 index 0000000000..95cc06958c --- /dev/null +++ b/api/tests/unit_tests/oss/__mock/local.py @@ -0,0 +1,57 @@ +import os +import shutil +from pathlib import Path +from unittest.mock import MagicMock, mock_open, patch + +import pytest +from _pytest.monkeypatch import MonkeyPatch + +from tests.unit_tests.oss.__mock.base import ( + get_example_data, + get_example_filename, + get_example_filepath, + get_example_folder, +) + + +class MockLocalFSClass: + def write_bytes(self, data): + assert data == get_example_data() + + def read_bytes(self): + return get_example_data() + + @staticmethod + def copyfile(src, dst): + assert src == os.path.join(get_example_folder(), get_example_filename()) + assert dst == get_example_filepath() + + @staticmethod + def exists(path): + assert path == os.path.join(get_example_folder(), get_example_filename()) + return True + + @staticmethod + def remove(path): + assert path == os.path.join(get_example_folder(), get_example_filename()) + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_local_fs_mock(monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(Path, "write_bytes", MockLocalFSClass.write_bytes) + monkeypatch.setattr(Path, "read_bytes", MockLocalFSClass.read_bytes) + monkeypatch.setattr(shutil, "copyfile", MockLocalFSClass.copyfile) + monkeypatch.setattr(os.path, "exists", MockLocalFSClass.exists) + monkeypatch.setattr(os, "remove", MockLocalFSClass.remove) + + os.makedirs = MagicMock() + + with patch("builtins.open", mock_open(read_data=get_example_data())): + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/unit_tests/oss/__mock/tencent_cos.py b/api/tests/unit_tests/oss/__mock/tencent_cos.py new file mode 100644 index 0000000000..5189b68e87 --- /dev/null +++ b/api/tests/unit_tests/oss/__mock/tencent_cos.py @@ -0,0 +1,81 @@ +import os +from unittest.mock import MagicMock + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from qcloud_cos import CosS3Client +from qcloud_cos.streambody import StreamBody + +from tests.unit_tests.oss.__mock.base import ( + get_example_bucket, + get_example_data, + get_example_filename, + get_example_filepath, +) + + +class MockTencentCosClass: + def __init__(self, conf, retry=1, session=None): + self.bucket_name = get_example_bucket() + self.key = get_example_filename() + self.content = get_example_data() + self.filepath = get_example_filepath() + self.resp = { + "ETag": "ee8de918d05640145b18f70f4c3aa602", + "Server": "tencent-cos", + "x-cos-hash-crc64ecma": 16749565679157681890, + "x-cos-request-id": "NWU5MDNkYzlfNjRiODJhMDlfMzFmYzhfMTFm****", + } + + def put_object(self, Bucket, Body, Key, EnableMD5=False, **kwargs): # noqa: N803 + assert Bucket == self.bucket_name + assert Key == self.key + assert Body == self.content + return self.resp + + def get_object(self, Bucket, Key, KeySimplifyCheck=True, **kwargs): # noqa: N803 + assert Bucket == self.bucket_name + assert Key == self.key + + mock_stream_body = MagicMock(StreamBody) + mock_raw_stream = MagicMock() + mock_stream_body.get_raw_stream.return_value = mock_raw_stream + mock_raw_stream.read.return_value = self.content + + mock_stream_body.get_stream_to_file = MagicMock() + + def chunk_generator(chunk_size=2): + for i in range(0, len(self.content), chunk_size): + yield self.content[i : i + chunk_size] + + mock_stream_body.get_stream.return_value = chunk_generator(chunk_size=4096) + return {"Body": mock_stream_body} + + def object_exists(self, Bucket, Key): # noqa: N803 + assert Bucket == self.bucket_name + assert Key == self.key + return True + + def delete_object(self, Bucket, Key, **kwargs): # noqa: N803 + assert Bucket == self.bucket_name + assert Key == self.key + self.resp.update({"x-cos-delete-marker": True}) + return self.resp + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_tencent_cos_mock(monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(CosS3Client, "__init__", MockTencentCosClass.__init__) + monkeypatch.setattr(CosS3Client, "put_object", MockTencentCosClass.put_object) + monkeypatch.setattr(CosS3Client, "get_object", MockTencentCosClass.get_object) + monkeypatch.setattr(CosS3Client, "object_exists", MockTencentCosClass.object_exists) + monkeypatch.setattr(CosS3Client, "delete_object", MockTencentCosClass.delete_object) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/unit_tests/oss/__mock/volcengine_tos.py b/api/tests/unit_tests/oss/__mock/volcengine_tos.py index 241764c521..1194a03258 100644 --- a/api/tests/unit_tests/oss/__mock/volcengine_tos.py +++ b/api/tests/unit_tests/oss/__mock/volcengine_tos.py @@ -1,5 +1,4 @@ import os -from typing import Union from unittest.mock import MagicMock import pytest @@ -7,28 +6,19 @@ from _pytest.monkeypatch import MonkeyPatch from tos import TosClientV2 from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput +from tests.unit_tests.oss.__mock.base import ( + get_example_bucket, + get_example_data, + get_example_filename, + get_example_filepath, +) + class AttrDict(dict): def __getattr__(self, item): return self.get(item) -def get_example_bucket() -> str: - return "dify" - - -def get_example_filename() -> str: - return "test.txt" - - -def get_example_data() -> bytes: - return b"test" - - -def get_example_filepath() -> str: - return "/test" - - class MockVolcengineTosClass: def __init__(self, ak="", sk="", endpoint="", region=""): self.bucket_name = get_example_bucket() diff --git a/api/tests/unit_tests/oss/local/test_local_fs.py b/api/tests/unit_tests/oss/local/test_local_fs.py new file mode 100644 index 0000000000..03ce7d2450 --- /dev/null +++ b/api/tests/unit_tests/oss/local/test_local_fs.py @@ -0,0 +1,18 @@ +from collections.abc import Generator + +import pytest + +from extensions.storage.local_fs_storage import LocalFsStorage +from tests.unit_tests.oss.__mock.base import ( + BaseStorageTest, + get_example_folder, +) +from tests.unit_tests.oss.__mock.local import setup_local_fs_mock + + +class TestLocalFS(BaseStorageTest): + @pytest.fixture(autouse=True) + def setup_method(self, setup_local_fs_mock): + """Executed before each test method.""" + self.storage = LocalFsStorage() + self.storage.folder = get_example_folder() diff --git a/api/tests/unit_tests/oss/tencent_cos/__init__.py b/api/tests/unit_tests/oss/tencent_cos/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py new file mode 100644 index 0000000000..303f0493bd --- /dev/null +++ b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py @@ -0,0 +1,20 @@ +from unittest.mock import patch + +import pytest +from qcloud_cos import CosConfig + +from extensions.storage.tencent_cos_storage import TencentCosStorage +from tests.unit_tests.oss.__mock.base import ( + BaseStorageTest, + get_example_bucket, +) +from tests.unit_tests.oss.__mock.tencent_cos import setup_tencent_cos_mock + + +class TestTencentCos(BaseStorageTest): + @pytest.fixture(autouse=True) + def setup_method(self, setup_tencent_cos_mock): + """Executed before each test method.""" + with patch.object(CosConfig, "__init__", return_value=None): + self.storage = TencentCosStorage() + self.storage.bucket_name = get_example_bucket() diff --git a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py index 545d18044d..5afbc9e8b4 100644 --- a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py +++ b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py @@ -1,30 +1,18 @@ -from collections.abc import Generator - -from flask import Flask +import pytest from tos import TosClientV2 -from tos.clientv2 import GetObjectOutput, HeadObjectOutput, PutObjectOutput from extensions.storage.volcengine_tos_storage import VolcengineTosStorage -from tests.unit_tests.oss.__mock.volcengine_tos import ( +from tests.unit_tests.oss.__mock.base import ( + BaseStorageTest, get_example_bucket, - get_example_data, - get_example_filename, - get_example_filepath, - setup_volcengine_tos_mock, ) +from tests.unit_tests.oss.__mock.volcengine_tos import setup_volcengine_tos_mock -class VolcengineTosTest: - _instance = None - - def __new__(cls): - if cls._instance == None: - cls._instance = object.__new__(cls) - return cls._instance - else: - return cls._instance - - def __init__(self): +class TestVolcengineTos(BaseStorageTest): + @pytest.fixture(autouse=True) + def setup_method(self, setup_volcengine_tos_mock): + """Executed before each test method.""" self.storage = VolcengineTosStorage() self.storage.bucket_name = get_example_bucket() self.storage.client = TosClientV2( @@ -33,35 +21,3 @@ class VolcengineTosTest: endpoint="https://xxx.volces.com", region="cn-beijing", ) - - -def test_save(setup_volcengine_tos_mock): - volc_tos = VolcengineTosTest() - volc_tos.storage.save(get_example_filename(), get_example_data()) - - -def test_load_once(setup_volcengine_tos_mock): - volc_tos = VolcengineTosTest() - assert volc_tos.storage.load_once(get_example_filename()) == get_example_data() - - -def test_load_stream(setup_volcengine_tos_mock): - volc_tos = VolcengineTosTest() - generator = volc_tos.storage.load_stream(get_example_filename()) - assert isinstance(generator, Generator) - assert next(generator) == get_example_data() - - -def test_download(setup_volcengine_tos_mock): - volc_tos = VolcengineTosTest() - volc_tos.storage.download(get_example_filename(), get_example_filepath()) - - -def test_exists(setup_volcengine_tos_mock): - volc_tos = VolcengineTosTest() - assert volc_tos.storage.exists(get_example_filename()) - - -def test_delete(setup_volcengine_tos_mock): - volc_tos = VolcengineTosTest() - volc_tos.storage.delete(get_example_filename()) diff --git a/dev/pytest/pytest_vdb.sh b/dev/pytest/pytest_vdb.sh index 579da6a30e..418a129693 100755 --- a/dev/pytest/pytest_vdb.sh +++ b/dev/pytest/pytest_vdb.sh @@ -11,4 +11,5 @@ pytest api/tests/integration_tests/vdb/chroma \ api/tests/integration_tests/vdb/vikingdb \ api/tests/integration_tests/vdb/baidu \ api/tests/integration_tests/vdb/tcvectordb \ - api/tests/integration_tests/vdb/upstash \ No newline at end of file + api/tests/integration_tests/vdb/upstash \ + api/tests/integration_tests/vdb/couchbase \ diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index 17b788ff81..e3f1c3b761 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.10.1 + image: langgenius/dify-api:0.10.2 restart: always environment: # Startup mode, 'api' starts the API server. @@ -227,7 +227,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.10.1 + image: langgenius/dify-api:0.10.2 restart: always environment: CONSOLE_WEB_URL: '' @@ -396,7 +396,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.10.1 + image: langgenius/dify-web:0.10.2 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/.env.example b/docker/.env.example index dbdc943b06..c506a9d92e 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -222,6 +222,7 @@ REDIS_PORT=6379 REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false +REDIS_DB=0 # Whether to use Redis Sentinel mode. # If set to true, the application will automatically discover and connect to the master node through Sentinel. @@ -273,6 +274,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=* # Supported values are `local` , `s3` , `azure-blob` , `google-storage`, `tencent-cos`, `huawei-obs`, `volcengine-tos`, `baidu-obs`, `supabase` # Default: `local` STORAGE_TYPE=local +STORAGE_LOCAL_PATH=storage # S3 Configuration # Whether to use AWS managed IAM roles for authenticating with the S3 service. @@ -373,7 +375,7 @@ SUPABASE_URL=your-server-url # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `analyticdb`, `vikingdb`. +# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `analyticdb`, `couchbase`, `vikingdb`. VECTOR_STORE=weaviate # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. @@ -412,6 +414,14 @@ MYSCALE_PASSWORD= MYSCALE_DATABASE=dify MYSCALE_FTS_PARAMS= +# Couchbase configurations, only available when VECTOR_STORE is `couchbase` +# The connection string must include hostname defined in the docker-compose file (couchbase-server in this case) +COUCHBASE_CONNECTION_STRING=couchbase://couchbase-server +COUCHBASE_USER=Administrator +COUCHBASE_PASSWORD=password +COUCHBASE_BUCKET_NAME=Embeddings +COUCHBASE_SCOPE_NAME=_default + # pgvector configurations, only available when VECTOR_STORE is `pgvector` PGVECTOR_HOST=pgvector PGVECTOR_PORT=5432 @@ -591,6 +601,7 @@ MAIL_DEFAULT_SEND_FROM= # API-Key for the Resend email provider, used when MAIL_TYPE is `resend`. RESEND_API_KEY=your-resend-api-key +RESEND_API_URL=https://api.resend.com # SMTP server configuration, used when MAIL_TYPE is `smtp` SMTP_SERVER= @@ -630,6 +641,7 @@ CODE_MAX_NUMBER_ARRAY_LENGTH=1000 WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 +MAX_VARIABLE_SIZE=204800 # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 diff --git a/docker/couchbase-server/Dockerfile b/docker/couchbase-server/Dockerfile new file mode 100644 index 0000000000..bd8af64150 --- /dev/null +++ b/docker/couchbase-server/Dockerfile @@ -0,0 +1,4 @@ +FROM couchbase/server:latest AS stage_base +# FROM couchbase:latest AS stage_base +COPY init-cbserver.sh /opt/couchbase/init/ +RUN chmod +x /opt/couchbase/init/init-cbserver.sh \ No newline at end of file diff --git a/docker/couchbase-server/init-cbserver.sh b/docker/couchbase-server/init-cbserver.sh new file mode 100755 index 0000000000..e66bc18530 --- /dev/null +++ b/docker/couchbase-server/init-cbserver.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# used to start couchbase server - can't get around this as docker compose only allows you to start one command - so we have to start couchbase like the standard couchbase Dockerfile would +# https://github.com/couchbase/docker/blob/master/enterprise/couchbase-server/7.2.0/Dockerfile#L88 + +/entrypoint.sh couchbase-server & + +# track if setup is complete so we don't try to setup again +FILE=/opt/couchbase/init/setupComplete.txt + +if ! [ -f "$FILE" ]; then + # used to automatically create the cluster based on environment variables + # https://docs.couchbase.com/server/current/cli/cbcli/couchbase-cli-cluster-init.html + + echo $COUCHBASE_ADMINISTRATOR_USERNAME ":" $COUCHBASE_ADMINISTRATOR_PASSWORD + + sleep 20s + /opt/couchbase/bin/couchbase-cli cluster-init -c 127.0.0.1 \ + --cluster-username $COUCHBASE_ADMINISTRATOR_USERNAME \ + --cluster-password $COUCHBASE_ADMINISTRATOR_PASSWORD \ + --services data,index,query,fts \ + --cluster-ramsize $COUCHBASE_RAM_SIZE \ + --cluster-index-ramsize $COUCHBASE_INDEX_RAM_SIZE \ + --cluster-eventing-ramsize $COUCHBASE_EVENTING_RAM_SIZE \ + --cluster-fts-ramsize $COUCHBASE_FTS_RAM_SIZE \ + --index-storage-setting default + + sleep 2s + + # used to auto create the bucket based on environment variables + # https://docs.couchbase.com/server/current/cli/cbcli/couchbase-cli-bucket-create.html + + /opt/couchbase/bin/couchbase-cli bucket-create -c localhost:8091 \ + --username $COUCHBASE_ADMINISTRATOR_USERNAME \ + --password $COUCHBASE_ADMINISTRATOR_PASSWORD \ + --bucket $COUCHBASE_BUCKET \ + --bucket-ramsize $COUCHBASE_BUCKET_RAMSIZE \ + --bucket-type couchbase + + # create file so we know that the cluster is setup and don't run the setup again + touch $FILE +fi + # docker compose will stop the container from running unless we do this + # known issue and workaround + tail -f /dev/null diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 0c9edd2b55..31624285b1 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -16,7 +16,7 @@ services: -c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}' -c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}' volumes: - - ./volumes/db/data:/var/lib/postgresql/data + - ${PGDATA_HOST_VOLUME:-./volumes/db/data}:/var/lib/postgresql/data ports: - "${EXPOSE_POSTGRES_PORT:-5432}:5432" healthcheck: @@ -31,7 +31,7 @@ services: restart: always volumes: # Mount the redis data directory to the container. - - ./volumes/redis/data:/data + - ${REDIS_HOST_VOLUME:-./volumes/redis/data}:/data # Set the redis password when startup redis server. command: redis-server --requirepass difyai123456 ports: @@ -94,7 +94,7 @@ services: restart: always volumes: # Mount the Weaviate data directory to the container. - - ./volumes/weaviate:/var/lib/weaviate + - ${WEAVIATE_HOST_VOLUME:-./volumes/weaviate}:/var/lib/weaviate env_file: - ./middleware.env environment: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 86dc866773..d5d087ab9a 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -43,17 +43,17 @@ x-shared-env: &shared-api-worker-env REDIS_USERNAME: ${REDIS_USERNAME:-} REDIS_PASSWORD: ${REDIS_PASSWORD:-difyai123456} REDIS_USE_SSL: ${REDIS_USE_SSL:-false} - REDIS_DB: 0 + REDIS_DB: ${REDIS_DB:-0} REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false} REDIS_SENTINELS: ${REDIS_SENTINELS:-} REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-} REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-} REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-} - ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1} REDIS_CLUSTERS: ${REDIS_CLUSTERS:-} REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false} REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-} + ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false} @@ -62,7 +62,7 @@ x-shared-env: &shared-api-worker-env WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*} CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*} STORAGE_TYPE: ${STORAGE_TYPE:-local} - STORAGE_LOCAL_PATH: storage + STORAGE_LOCAL_PATH: ${STORAGE_LOCAL_PATH:-storage} S3_USE_AWS_MANAGED_IAM: ${S3_USE_AWS_MANAGED_IAM:-false} S3_ENDPOINT: ${S3_ENDPOINT:-} S3_BUCKET_NAME: ${S3_BUCKET_NAME:-} @@ -113,6 +113,11 @@ x-shared-env: &shared-api-worker-env QDRANT_CLIENT_TIMEOUT: ${QDRANT_CLIENT_TIMEOUT:-20} QDRANT_GRPC_ENABLED: ${QDRANT_GRPC_ENABLED:-false} QDRANT_GRPC_PORT: ${QDRANT_GRPC_PORT:-6334} + COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-'couchbase-server'} + COUCHBASE_USER: ${COUCHBASE_USER:-Administrator} + COUCHBASE_PASSWORD: ${COUCHBASE_PASSWORD:-password} + COUCHBASE_BUCKET_NAME: ${COUCHBASE_BUCKET_NAME:-Embeddings} + COUCHBASE_SCOPE_NAME: ${COUCHBASE_SCOPE_NAME:-_default} MILVUS_URI: ${MILVUS_URI:-http://127.0.0.1:19530} MILVUS_TOKEN: ${MILVUS_TOKEN:-} MILVUS_USER: ${MILVUS_USER:-root} @@ -211,7 +216,7 @@ x-shared-env: &shared-api-worker-env SMTP_USE_TLS: ${SMTP_USE_TLS:-true} SMTP_OPPORTUNISTIC_TLS: ${SMTP_OPPORTUNISTIC_TLS:-false} RESEND_API_KEY: ${RESEND_API_KEY:-your-resend-api-key} - RESEND_API_URL: https://api.resend.com + RESEND_API_URL: ${RESEND_API_URL:-https://api.resend.com} INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-1000} INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72} RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5} @@ -240,11 +245,12 @@ x-shared-env: &shared-api-worker-env POSITION_PROVIDER_PINS: ${POSITION_PROVIDER_PINS:-} POSITION_PROVIDER_INCLUDES: ${POSITION_PROVIDER_INCLUDES:-} POSITION_PROVIDER_EXCLUDES: ${POSITION_PROVIDER_EXCLUDES:-} + MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800} services: # API service api: - image: langgenius/dify-api:0.10.1 + image: langgenius/dify-api:0.10.2 restart: always environment: # Use the shared environment variables. @@ -264,7 +270,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.10.1 + image: langgenius/dify-api:0.10.2 restart: always environment: # Use the shared environment variables. @@ -283,7 +289,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.10.1 + image: langgenius/dify-web:0.10.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -477,6 +483,39 @@ services: environment: QDRANT_API_KEY: ${QDRANT_API_KEY:-difyai123456} + # The Couchbase vector store. + couchbase-server: + build: ./couchbase-server + profiles: + - couchbase + restart: always + environment: + - CLUSTER_NAME=dify_search + - COUCHBASE_ADMINISTRATOR_USERNAME=${COUCHBASE_USER:-Administrator} + - COUCHBASE_ADMINISTRATOR_PASSWORD=${COUCHBASE_PASSWORD:-password} + - COUCHBASE_BUCKET=${COUCHBASE_BUCKET_NAME:-Embeddings} + - COUCHBASE_BUCKET_RAMSIZE=512 + - COUCHBASE_RAM_SIZE=2048 + - COUCHBASE_EVENTING_RAM_SIZE=512 + - COUCHBASE_INDEX_RAM_SIZE=512 + - COUCHBASE_FTS_RAM_SIZE=1024 + hostname: couchbase-server + container_name: couchbase-server + working_dir: /opt/couchbase + stdin_open: true + tty: true + entrypoint: [""] + command: sh -c "/opt/couchbase/init/init-cbserver.sh" + volumes: + - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data + healthcheck: + # ensure bucket was created before proceeding + test: [ "CMD-SHELL", "curl -s -f -u Administrator:password http://localhost:8091/pools/default/buckets | grep -q '\\[{' || exit 1" ] + interval: 10s + retries: 10 + start_period: 30s + timeout: 10s + # The pgvector vector database. pgvector: image: pgvector/pgvector:pg16 diff --git a/docker/middleware.env.example b/docker/middleware.env.example index 04d0fb5ed3..17ac819527 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -8,6 +8,7 @@ POSTGRES_PASSWORD=difyai123456 POSTGRES_DB=dify # postgres data directory PGDATA=/var/lib/postgresql/data/pgdata +PGDATA_HOST_VOLUME=./volumes/db/data # Maximum number of connections to the database # Default is 100 @@ -39,6 +40,11 @@ POSTGRES_MAINTENANCE_WORK_MEM=64MB # Reference: https://www.postgresql.org/docs/current/runtime-config-query.html#GUC-EFFECTIVE-CACHE-SIZE POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB +# ----------------------------- +# Environment Variables for redis Service +REDIS_HOST_VOLUME=./volumes/redis/data +# ----------------------------- + # ------------------------------ # Environment Variables for sandbox Service SANDBOX_API_KEY=dify-sandbox @@ -70,6 +76,7 @@ WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih WEAVIATE_AUTHENTICATION_APIKEY_USERS=hello@dify.ai WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED=true WEAVIATE_AUTHORIZATION_ADMINLIST_USERS=hello@dify.ai +WEAVIATE_HOST_VOLUME=./volumes/weaviate # ------------------------------ # Docker Compose Service Expose Host Port Configurations diff --git a/web/__mocks__/mime.js b/web/__mocks__/mime.js new file mode 100644 index 0000000000..e69de29bb2 diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index b846f6d9fb..e264fd707e 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -236,12 +236,31 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from Knowledge name + + Knowledge description (optional) + + + Index Technique (optional) + - high_quality high_quality + - economy economy + Permission - only_me Only me - all_team_members All team members - partial_members Partial members + + Provider (optional, default: vendor) + - vendor vendor + - external external knowledge + + + External Knowledge api id (optional) + + + External Knowledge id (optional) + diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index ece4d3b771..5d52664db4 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -234,14 +234,33 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - 知识库名称 + 知识库名称(必填) + + + 知识库描述(选填) + + + 索引模式(选填,建议填写) + - high_quality 高质量 + - economy 经济 - 权限 + 权限(选填,默认only_me) - only_me 仅自己 - all_team_members 所有团队成员 - partial_members 部分团队成员 + + provider,(选填,默认 vendor) + - vendor 上传文件 + - external 外部知识库 + + + 外部知识库 API_ID(选填) + + + 外部知识库 ID(选填) + diff --git a/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx b/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx index b2d45d2733..b63e3e2693 100644 --- a/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx +++ b/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx @@ -7,8 +7,7 @@ import ConfigPrompt from '../../config-prompt' import { languageMap } from '../../../../workflow/nodes/_base/components/editor/code-editor/index' import { generateRuleCode } from '@/service/debug' import type { CodeGenRes } from '@/service/debug' -import { ModelModeType } from '@/types/app' -import type { AppType, Model } from '@/types/app' +import { type AppType, type Model, ModelModeType } from '@/types/app' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' import { Generator } from '@/app/components/base/icons/src/vender/other' @@ -16,6 +15,10 @@ import Toast from '@/app/components/base/toast' import Loading from '@/app/components/base/loading' import Confirm from '@/app/components/base/confirm' import type { CodeLanguage } from '@/app/components/workflow/nodes/code/types' +import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon' +import ModelName from '@/app/components/header/account-setting/model-provider-page/model-name' export type IGetCodeGeneratorResProps = { mode: AppType isShow: boolean @@ -31,9 +34,12 @@ export const GetCodeGeneratorResModal: FC = ( codeLanguages, onClose, onFinished, - }, ) => { + const { + currentProvider, + currentModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) const { t } = useTranslation() const [instruction, setInstruction] = React.useState('') const [isLoading, { setTrue: setLoadingTrue, setFalse: setLoadingFalse }] = useBoolean(false) @@ -51,9 +57,10 @@ export const GetCodeGeneratorResModal: FC = ( return true } const model: Model = { - provider: 'openai', - name: 'gpt-4o-mini', + provider: currentProvider?.provider || '', + name: currentModel?.model || '', mode: ModelModeType.chat, + // This is a fixed parameter completion_params: { temperature: 0.7, max_tokens: 0, @@ -112,6 +119,19 @@ export const GetCodeGeneratorResModal: FC = (

{t('appDebug.codegen.title')}
{t('appDebug.codegen.description')}
+
+ + +
{t('appDebug.codegen.instruction')}
diff --git a/web/app/components/app/configuration/debug/index.tsx b/web/app/components/app/configuration/debug/index.tsx index 71e441d415..480bd782ae 100644 --- a/web/app/components/app/configuration/debug/index.tsx +++ b/web/app/components/app/configuration/debug/index.tsx @@ -1,6 +1,5 @@ 'use client' import type { FC } from 'react' -import useSWR from 'swr' import { useTranslation } from 'react-i18next' import React, { useCallback, useEffect, useRef, useState } from 'react' import produce, { setAutoFreeze } from 'immer' @@ -39,7 +38,6 @@ import { promptVariablesToUserInputsForm } from '@/utils/model-config' import TextGeneration from '@/app/components/app/text-generate/item' import { IS_CE_EDITION } from '@/config' import type { Inputs } from '@/models/debug' -import { fetchFileUploadConfig } from '@/service/common' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import type { ModelParameterModalProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' @@ -94,7 +92,6 @@ const Debug: FC = ({ } = useContext(ConfigContext) const { eventEmitter } = useEventEmitterContextContext() const { data: text2speechDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding) - const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) useEffect(() => { setAutoFreeze(false) return () => { @@ -452,7 +449,7 @@ const Debug: FC = ({ visionConfig={{ ...features.file! as VisionSettings, transfer_methods: features.file!.allowed_file_upload_methods || [], - image_file_size_limit: fileUploadConfigResponse?.image_file_size_limit, + image_file_size_limit: features.file?.fileUploadConfig?.image_file_size_limit, }} onVisionFilesChange={setCompletionFiles} /> diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 434b54ab91..12ee7d75ad 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -1,6 +1,7 @@ 'use client' import type { FC } from 'react' import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import useSWR from 'swr' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { usePathname } from 'next/navigation' @@ -69,6 +70,7 @@ import type { Features as FeaturesData, FileUpload } from '@/app/components/base import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' import { SupportUploadFileTypes } from '@/app/components/workflow/types' import NewFeaturePanel from '@/app/components/base/features/new-feature-panel' +import { fetchFileUploadConfig } from '@/service/common' type PublishConfig = { modelConfig: ModelConfig @@ -84,6 +86,8 @@ const Configuration: FC = () => { showAppConfigureFeaturesModal: state.showAppConfigureFeaturesModal, setShowAppConfigureFeaturesModal: state.setShowAppConfigureFeaturesModal, }))) + const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) + const latestPublishedAt = useMemo(() => appDetail?.model_config.updated_at, [appDetail]) const [formattingChanged, setFormattingChanged] = useState(false) const { setShowAccountSettingModal } = useModalContext() @@ -462,12 +466,13 @@ const Configuration: FC = () => { allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`), allowed_file_upload_methods: modelConfig.file_upload?.allowed_file_upload_methods || modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], number_limits: modelConfig.file_upload?.number_limits || modelConfig.file_upload?.image?.number_limits || 3, + fileUploadConfig: fileUploadConfigResponse, } as FileUpload, suggested: modelConfig.suggested_questions_after_answer || { enabled: false }, citation: modelConfig.retriever_resource || { enabled: false }, annotationReply: modelConfig.annotation_reply || { enabled: false }, } - }, [modelConfig]) + }, [fileUploadConfigResponse, modelConfig]) const handleFeaturesChange = useCallback((flag: any) => { setShowAppConfigureFeaturesModal(true) if (flag) @@ -684,6 +689,9 @@ const Configuration: FC = () => { }, })) + const fileUpload = { ...features?.file } + delete fileUpload?.fileUploadConfig + // new model config data struct const data: BackendModelConfig = { // Simple Mode prompt @@ -700,7 +708,7 @@ const Configuration: FC = () => { sensitive_word_avoidance: features?.moderation as any, speech_to_text: features?.speech2text as any, text_to_speech: features?.text2speech as any, - file_upload: features?.file as any, + file_upload: fileUpload as any, suggested_questions_after_answer: features?.suggested as any, retriever_resource: features?.citation as any, agent_mode: { diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 6d643a01a3..22585aa678 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -16,7 +16,7 @@ import timezone from 'dayjs/plugin/timezone' import { createContext, useContext } from 'use-context-selector' import { useShallow } from 'zustand/react/shallow' import { useTranslation } from 'react-i18next' -import { UUID_NIL } from '../../base/chat/constants' +import type { ChatItemInTree } from '../../base/chat/types' import VarPanel from './var-panel' import cn from '@/utils/classnames' import type { FeedbackFunc, FeedbackType, IChatItem, SubmitAnnotationFunc } from '@/app/components/base/chat/chat/type' @@ -41,6 +41,7 @@ import { useAppContext } from '@/context/app-context' import useTimestamp from '@/hooks/use-timestamp' import Tooltip from '@/app/components/base/tooltip' import { CopyIcon } from '@/app/components/base/copy-icon' +import { buildChatItemTree, getThreadMessages } from '@/app/components/base/chat/utils' import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' dayjs.extend(utc) @@ -82,94 +83,76 @@ const PARAM_MAP = { frequency_penalty: 'Frequency Penalty', } -function appendQAToChatList(newChatList: IChatItem[], item: any, conversationId: string, timezone: string, format: string) { - const answerFiles = item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [] - newChatList.push({ - id: item.id, - content: item.answer, - agent_thoughts: addFileInfos(item.agent_thoughts ? sortAgentSorts(item.agent_thoughts) : item.agent_thoughts, item.message_files), - feedback: item.feedbacks.find((item: any) => item.from_source === 'user'), // user feedback - adminFeedback: item.feedbacks.find((item: any) => item.from_source === 'admin'), // admin feedback - feedbackDisabled: false, - isAnswer: true, - message_files: getProcessedFilesFromResponse(answerFiles.map((item: any) => ({ ...item, related_id: item.id }))), - log: [ - ...item.message, - ...(item.message[item.message.length - 1]?.role !== 'assistant' - ? [ - { - role: 'assistant', - text: item.answer, - files: item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [], - }, - ] - : []), - ], - workflow_run_id: item.workflow_run_id, - conversationId, - input: { - inputs: item.inputs, - query: item.query, - }, - more: { - time: dayjs.unix(item.created_at).tz(timezone).format(format), - tokens: item.answer_tokens + item.message_tokens, - latency: item.provider_response_latency.toFixed(2), - }, - citation: item.metadata?.retriever_resources, - annotation: (() => { - if (item.annotation_hit_history) { - return { - id: item.annotation_hit_history.annotation_id, - authorName: item.annotation_hit_history.annotation_create_account?.name || 'N/A', - created_at: item.annotation_hit_history.created_at, - } - } - - if (item.annotation) { - return { - id: item.annotation.id, - authorName: item.annotation.account.name, - logAnnotation: item.annotation, - created_at: 0, - } - } - - return undefined - })(), - parentMessageId: `question-${item.id}`, - }) - const questionFiles = item.message_files?.filter((file: any) => file.belongs_to === 'user') || [] - newChatList.push({ - id: `question-${item.id}`, - content: item.inputs.query || item.inputs.default_input || item.query, // text generation: item.inputs.query; chat: item.query - isAnswer: false, - message_files: getProcessedFilesFromResponse(questionFiles.map((item: any) => ({ ...item, related_id: item.id }))), - parentMessageId: item.parent_message_id || undefined, - }) -} - const getFormattedChatList = (messages: ChatMessage[], conversationId: string, timezone: string, format: string) => { const newChatList: IChatItem[] = [] - let nextMessageId = null - for (const item of messages) { - if (!item.parent_message_id) { - appendQAToChatList(newChatList, item, conversationId, timezone, format) - break - } + messages.forEach((item: ChatMessage) => { + const questionFiles = item.message_files?.filter((file: any) => file.belongs_to === 'user') || [] + newChatList.push({ + id: `question-${item.id}`, + content: item.inputs.query || item.inputs.default_input || item.query, // text generation: item.inputs.query; chat: item.query + isAnswer: false, + message_files: getProcessedFilesFromResponse(questionFiles.map((item: any) => ({ ...item, related_id: item.id }))), + parentMessageId: item.parent_message_id || undefined, + }) - if (!nextMessageId) { - appendQAToChatList(newChatList, item, conversationId, timezone, format) - nextMessageId = item.parent_message_id - } - else { - if (item.id === nextMessageId || nextMessageId === UUID_NIL) { - appendQAToChatList(newChatList, item, conversationId, timezone, format) - nextMessageId = item.parent_message_id - } - } - } - return newChatList.reverse() + const answerFiles = item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [] + newChatList.push({ + id: item.id, + content: item.answer, + agent_thoughts: addFileInfos(item.agent_thoughts ? sortAgentSorts(item.agent_thoughts) : item.agent_thoughts, item.message_files), + feedback: item.feedbacks.find(item => item.from_source === 'user'), // user feedback + adminFeedback: item.feedbacks.find(item => item.from_source === 'admin'), // admin feedback + feedbackDisabled: false, + isAnswer: true, + message_files: getProcessedFilesFromResponse(answerFiles.map((item: any) => ({ ...item, related_id: item.id }))), + log: [ + ...item.message, + ...(item.message[item.message.length - 1]?.role !== 'assistant' + ? [ + { + role: 'assistant', + text: item.answer, + files: item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [], + }, + ] + : []), + ] as IChatItem['log'], + workflow_run_id: item.workflow_run_id, + conversationId, + input: { + inputs: item.inputs, + query: item.query, + }, + more: { + time: dayjs.unix(item.created_at).tz(timezone).format(format), + tokens: item.answer_tokens + item.message_tokens, + latency: item.provider_response_latency.toFixed(2), + }, + citation: item.metadata?.retriever_resources, + annotation: (() => { + if (item.annotation_hit_history) { + return { + id: item.annotation_hit_history.annotation_id, + authorName: item.annotation_hit_history.annotation_create_account?.name || 'N/A', + created_at: item.annotation_hit_history.created_at, + } + } + + if (item.annotation) { + return { + id: item.annotation.id, + authorName: item.annotation.account.name, + logAnnotation: item.annotation, + created_at: 0, + } + } + + return undefined + })(), + parentMessageId: `question-${item.id}`, + }) + }) + return newChatList } // const displayedParams = CompletionParams.slice(0, -2) @@ -193,50 +176,66 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { currentLogModalActiveTab: state.currentLogModalActiveTab, }))) const { t } = useTranslation() - const [items, setItems] = React.useState([]) - const fetchedMessages = useRef([]) const [hasMore, setHasMore] = useState(true) const [varValues, setVarValues] = useState>({}) - const fetchData = async () => { + + const [allChatItems, setAllChatItems] = useState([]) + const [chatItemTree, setChatItemTree] = useState([]) + const [threadChatItems, setThreadChatItems] = useState([]) + + const fetchData = useCallback(async () => { try { if (!hasMore) return + const params: ChatMessagesRequest = { conversation_id: detail.id, limit: 10, } - if (items?.[0]?.id) - params.first_id = items?.[0]?.id.replace('question-', '') - + if (allChatItems.at(-1)?.id) + params.first_id = allChatItems.at(-1)?.id.replace('question-', '') const messageRes = await fetchChatMessages({ url: `/apps/${appDetail?.id}/chat-messages`, params, }) if (messageRes.data.length > 0) { - const varValues = messageRes.data[0].inputs + const varValues = messageRes.data.at(-1)!.inputs setVarValues(varValues) } - fetchedMessages.current = [...fetchedMessages.current, ...messageRes.data] - const newItems = getFormattedChatList(fetchedMessages.current, detail.id, timezone!, t('appLog.dateTimeFormat') as string) + setHasMore(messageRes.has_more) + + const newAllChatItems = [ + ...getFormattedChatList(messageRes.data, detail.id, timezone!, t('appLog.dateTimeFormat') as string), + ...allChatItems, + ] + setAllChatItems(newAllChatItems) + + let tree = buildChatItemTree(newAllChatItems) if (messageRes.has_more === false && detail?.model_config?.configs?.introduction) { - newItems.unshift({ + tree = [{ id: 'introduction', isAnswer: true, isOpeningStatement: true, content: detail?.model_config?.configs?.introduction ?? 'hello', feedbackDisabled: true, - }) + children: tree, + }] } - setItems(newItems) - setHasMore(messageRes.has_more) + setChatItemTree(tree) + + setThreadChatItems(getThreadMessages(tree, newAllChatItems.at(-1)?.id)) } catch (err) { console.error(err) } - } + }, [allChatItems, detail.id, hasMore, timezone, t, appDetail, detail?.model_config?.configs?.introduction]) + + const switchSibling = useCallback((siblingMessageId: string) => { + setThreadChatItems(getThreadMessages(chatItemTree, siblingMessageId)) + }, [chatItemTree]) const handleAnnotationEdited = useCallback((query: string, answer: string, index: number) => { - setItems(items.map((item, i) => { + setAllChatItems(allChatItems.map((item, i) => { if (i === index - 1) { return { ...item, @@ -257,9 +256,9 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { } return item })) - }, [items]) + }, [allChatItems]) const handleAnnotationAdded = useCallback((annotationId: string, authorName: string, query: string, answer: string, index: number) => { - setItems(items.map((item, i) => { + setAllChatItems(allChatItems.map((item, i) => { if (i === index - 1) { return { ...item, @@ -287,9 +286,9 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { } return item })) - }, [items]) + }, [allChatItems]) const handleAnnotationRemoved = useCallback((index: number) => { - setItems(items.map((item, i) => { + setAllChatItems(allChatItems.map((item, i) => { if (i === index) { return { ...item, @@ -299,7 +298,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { } return item })) - }, [items]) + }, [allChatItems]) const fetchInitiated = useRef(false) @@ -464,7 +463,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { siteInfo={null} />
- : (items.length < 8 && !hasMore) + : threadChatItems.length < 8 ?
:
{t('appLog.detail.loading')}...
} @@ -532,7 +532,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { }, supportFeedback: true, } as any} - chatList={items} + chatList={threadChatItems} onAnnotationAdded={handleAnnotationAdded} onAnnotationEdited={handleAnnotationEdited} onAnnotationRemoved={handleAnnotationRemoved} @@ -541,6 +541,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { showPromptLog hideProcessDetail chatContainerInnerClassName='px-6' + switchSibling={switchSibling} />
diff --git a/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap b/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap new file mode 100644 index 0000000000..070975bfa7 --- /dev/null +++ b/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap @@ -0,0 +1,2281 @@ +// Jest Snapshot v1, https://goo.gl/fbAQLP + +exports[`build chat item tree and get thread messages should get thread messages from tree6, using specified message as target 1`] = ` +[ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105791, + "files": [], + "id": "f9d7ff7c-3a3b-4d9a-a289-657817f4caff", + "message_id": "ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "observation": "", + "position": 1, + "thought": "Sure, I'll play! My number is 57. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105795, + "files": [], + "id": "f61a3fce-37ac-4f9d-9935-95f97e598dfe", + "message_id": "73bbad14-d915-499d-87bf-0df14d40779d", + "observation": "", + "position": 1, + "thought": "I choose 83. What's your next number?", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105799, + "files": [], + "id": "9730d587-9268-4683-9dd9-91a1cab9510b", + "message_id": "4c5d0841-1206-463e-95d8-71f812877658", + "observation": "", + "position": 1, + "thought": "I'll go with 112. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "I'll go with 112. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "4c5d0841-1206-463e-95d8-71f812877658", + "input": { + "inputs": {}, + "query": "99", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "58", + }, + { + "files": [], + "role": "assistant", + "text": "I choose 83. What's your next number?", + }, + { + "files": [], + "role": "user", + "text": "99", + }, + { + "files": [], + "role": "assistant", + "text": "I'll go with 112. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.49", + "time": "09/11/2024 09:50 PM", + "tokens": 86, + }, + "parentMessageId": "question-4c5d0841-1206-463e-95d8-71f812877658", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "99", + "id": "question-4c5d0841-1206-463e-95d8-71f812877658", + "isAnswer": false, + "message_files": [], + "parentMessageId": "73bbad14-d915-499d-87bf-0df14d40779d", + }, + ], + "content": "I choose 83. What's your next number?", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "73bbad14-d915-499d-87bf-0df14d40779d", + "input": { + "inputs": {}, + "query": "58", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "58", + }, + { + "files": [], + "role": "assistant", + "text": "I choose 83. What's your next number?", + }, + ], + "message_files": [], + "more": { + "latency": "1.33", + "time": "09/11/2024 09:49 PM", + "tokens": 68, + }, + "parentMessageId": "question-73bbad14-d915-499d-87bf-0df14d40779d", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "58", + "id": "question-73bbad14-d915-499d-87bf-0df14d40779d", + "isAnswer": false, + "message_files": [], + "parentMessageId": "ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + }, + ], + "content": "Sure, I'll play! My number is 57. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "input": { + "inputs": {}, + "query": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.56", + "time": "09/11/2024 09:49 PM", + "tokens": 49, + }, + "parentMessageId": "question-ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "Let's play a game, I say a number , and you response me with another bigger, yet random-looking number. I'll start first, 38", + "id": "question-ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "isAnswer": false, + "message_files": [], + }, + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105791, + "files": [], + "id": "f9d7ff7c-3a3b-4d9a-a289-657817f4caff", + "message_id": "ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "observation": "", + "position": 1, + "thought": "Sure, I'll play! My number is 57. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105795, + "files": [], + "id": "f61a3fce-37ac-4f9d-9935-95f97e598dfe", + "message_id": "73bbad14-d915-499d-87bf-0df14d40779d", + "observation": "", + "position": 1, + "thought": "I choose 83. What's your next number?", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105799, + "files": [], + "id": "9730d587-9268-4683-9dd9-91a1cab9510b", + "message_id": "4c5d0841-1206-463e-95d8-71f812877658", + "observation": "", + "position": 1, + "thought": "I'll go with 112. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "I'll go with 112. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "4c5d0841-1206-463e-95d8-71f812877658", + "input": { + "inputs": {}, + "query": "99", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "58", + }, + { + "files": [], + "role": "assistant", + "text": "I choose 83. What's your next number?", + }, + { + "files": [], + "role": "user", + "text": "99", + }, + { + "files": [], + "role": "assistant", + "text": "I'll go with 112. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.49", + "time": "09/11/2024 09:50 PM", + "tokens": 86, + }, + "parentMessageId": "question-4c5d0841-1206-463e-95d8-71f812877658", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "99", + "id": "question-4c5d0841-1206-463e-95d8-71f812877658", + "isAnswer": false, + "message_files": [], + "parentMessageId": "73bbad14-d915-499d-87bf-0df14d40779d", + }, + ], + "content": "I choose 83. What's your next number?", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "73bbad14-d915-499d-87bf-0df14d40779d", + "input": { + "inputs": {}, + "query": "58", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "58", + }, + { + "files": [], + "role": "assistant", + "text": "I choose 83. What's your next number?", + }, + ], + "message_files": [], + "more": { + "latency": "1.33", + "time": "09/11/2024 09:49 PM", + "tokens": 68, + }, + "parentMessageId": "question-73bbad14-d915-499d-87bf-0df14d40779d", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "58", + "id": "question-73bbad14-d915-499d-87bf-0df14d40779d", + "isAnswer": false, + "message_files": [], + "parentMessageId": "ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + }, + ], + "content": "Sure, I'll play! My number is 57. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "input": { + "inputs": {}, + "query": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.56", + "time": "09/11/2024 09:49 PM", + "tokens": 49, + }, + "nextSibling": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "parentMessageId": "question-ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "prevSibling": undefined, + "siblingCount": 2, + "siblingIndex": 0, + "workflow_run_id": null, + }, + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105795, + "files": [], + "id": "f61a3fce-37ac-4f9d-9935-95f97e598dfe", + "message_id": "73bbad14-d915-499d-87bf-0df14d40779d", + "observation": "", + "position": 1, + "thought": "I choose 83. What's your next number?", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105799, + "files": [], + "id": "9730d587-9268-4683-9dd9-91a1cab9510b", + "message_id": "4c5d0841-1206-463e-95d8-71f812877658", + "observation": "", + "position": 1, + "thought": "I'll go with 112. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "I'll go with 112. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "4c5d0841-1206-463e-95d8-71f812877658", + "input": { + "inputs": {}, + "query": "99", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "58", + }, + { + "files": [], + "role": "assistant", + "text": "I choose 83. What's your next number?", + }, + { + "files": [], + "role": "user", + "text": "99", + }, + { + "files": [], + "role": "assistant", + "text": "I'll go with 112. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.49", + "time": "09/11/2024 09:50 PM", + "tokens": 86, + }, + "parentMessageId": "question-4c5d0841-1206-463e-95d8-71f812877658", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "99", + "id": "question-4c5d0841-1206-463e-95d8-71f812877658", + "isAnswer": false, + "message_files": [], + "parentMessageId": "73bbad14-d915-499d-87bf-0df14d40779d", + }, + ], + "content": "I choose 83. What's your next number?", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "73bbad14-d915-499d-87bf-0df14d40779d", + "input": { + "inputs": {}, + "query": "58", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "58", + }, + { + "files": [], + "role": "assistant", + "text": "I choose 83. What's your next number?", + }, + ], + "message_files": [], + "more": { + "latency": "1.33", + "time": "09/11/2024 09:49 PM", + "tokens": 68, + }, + "parentMessageId": "question-73bbad14-d915-499d-87bf-0df14d40779d", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "58", + "id": "question-73bbad14-d915-499d-87bf-0df14d40779d", + "isAnswer": false, + "message_files": [], + "parentMessageId": "ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + }, + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105795, + "files": [], + "id": "f61a3fce-37ac-4f9d-9935-95f97e598dfe", + "message_id": "73bbad14-d915-499d-87bf-0df14d40779d", + "observation": "", + "position": 1, + "thought": "I choose 83. What's your next number?", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105799, + "files": [], + "id": "9730d587-9268-4683-9dd9-91a1cab9510b", + "message_id": "4c5d0841-1206-463e-95d8-71f812877658", + "observation": "", + "position": 1, + "thought": "I'll go with 112. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "I'll go with 112. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "4c5d0841-1206-463e-95d8-71f812877658", + "input": { + "inputs": {}, + "query": "99", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "58", + }, + { + "files": [], + "role": "assistant", + "text": "I choose 83. What's your next number?", + }, + { + "files": [], + "role": "user", + "text": "99", + }, + { + "files": [], + "role": "assistant", + "text": "I'll go with 112. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.49", + "time": "09/11/2024 09:50 PM", + "tokens": 86, + }, + "parentMessageId": "question-4c5d0841-1206-463e-95d8-71f812877658", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "99", + "id": "question-4c5d0841-1206-463e-95d8-71f812877658", + "isAnswer": false, + "message_files": [], + "parentMessageId": "73bbad14-d915-499d-87bf-0df14d40779d", + }, + ], + "content": "I choose 83. What's your next number?", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "73bbad14-d915-499d-87bf-0df14d40779d", + "input": { + "inputs": {}, + "query": "58", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "58", + }, + { + "files": [], + "role": "assistant", + "text": "I choose 83. What's your next number?", + }, + ], + "message_files": [], + "more": { + "latency": "1.33", + "time": "09/11/2024 09:49 PM", + "tokens": 68, + }, + "parentMessageId": "question-73bbad14-d915-499d-87bf-0df14d40779d", + "prevSibling": undefined, + "siblingCount": 1, + "siblingIndex": 0, + "workflow_run_id": null, + }, + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105799, + "files": [], + "id": "9730d587-9268-4683-9dd9-91a1cab9510b", + "message_id": "4c5d0841-1206-463e-95d8-71f812877658", + "observation": "", + "position": 1, + "thought": "I'll go with 112. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "I'll go with 112. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "4c5d0841-1206-463e-95d8-71f812877658", + "input": { + "inputs": {}, + "query": "99", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "58", + }, + { + "files": [], + "role": "assistant", + "text": "I choose 83. What's your next number?", + }, + { + "files": [], + "role": "user", + "text": "99", + }, + { + "files": [], + "role": "assistant", + "text": "I'll go with 112. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.49", + "time": "09/11/2024 09:50 PM", + "tokens": 86, + }, + "parentMessageId": "question-4c5d0841-1206-463e-95d8-71f812877658", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "99", + "id": "question-4c5d0841-1206-463e-95d8-71f812877658", + "isAnswer": false, + "message_files": [], + "parentMessageId": "73bbad14-d915-499d-87bf-0df14d40779d", + }, + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105799, + "files": [], + "id": "9730d587-9268-4683-9dd9-91a1cab9510b", + "message_id": "4c5d0841-1206-463e-95d8-71f812877658", + "observation": "", + "position": 1, + "thought": "I'll go with 112. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "I'll go with 112. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "4c5d0841-1206-463e-95d8-71f812877658", + "input": { + "inputs": {}, + "query": "99", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "58", + }, + { + "files": [], + "role": "assistant", + "text": "I choose 83. What's your next number?", + }, + { + "files": [], + "role": "user", + "text": "99", + }, + { + "files": [], + "role": "assistant", + "text": "I'll go with 112. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.49", + "time": "09/11/2024 09:50 PM", + "tokens": 86, + }, + "parentMessageId": "question-4c5d0841-1206-463e-95d8-71f812877658", + "prevSibling": undefined, + "siblingCount": 1, + "siblingIndex": 0, + "workflow_run_id": null, + }, +] +`; + +exports[`build chat item tree and get thread messages should get thread messages from tree6, using the last message as target 1`] = ` +[ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105809, + "files": [], + "id": "1019cd79-d141-4f9f-880a-fc1441cfd802", + "message_id": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "observation": "", + "position": 1, + "thought": "Sure! My number is 54. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105822, + "files": [], + "id": "0773bec7-b992-4a53-92b2-20ebaeae8798", + "message_id": "324bce32-c98c-435d-a66b-bac974ebb5ed", + "observation": "", + "position": 1, + "thought": "My number is 4729. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "My number is 4729. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "324bce32-c98c-435d-a66b-bac974ebb5ed", + "input": { + "inputs": {}, + "query": "3306", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4729. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.30", + "time": "09/11/2024 09:50 PM", + "tokens": 66, + }, + "parentMessageId": "question-324bce32-c98c-435d-a66b-bac974ebb5ed", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "3306", + "id": "question-324bce32-c98c-435d-a66b-bac974ebb5ed", + "isAnswer": false, + "message_files": [], + "parentMessageId": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + }, + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726107812, + "files": [], + "id": "5ca650f3-982c-4399-8b95-9ea241c76707", + "message_id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "observation": "", + "position": 1, + "thought": "My number is 4821. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726111024, + "files": [], + "id": "095cacab-afad-4387-a41d-1662578b8b13", + "message_id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "observation": "", + "position": 1, + "thought": "My number is 1456. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "My number is 1456. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "input": { + "inputs": {}, + "query": "1003", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "1003", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 1456. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.38", + "time": "09/11/2024 11:17 PM", + "tokens": 86, + }, + "parentMessageId": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "1003", + "id": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "isAnswer": false, + "message_files": [], + "parentMessageId": "684b5396-4e91-4043-88e9-aabe48b21acc", + }, + ], + "content": "My number is 4821. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "input": { + "inputs": {}, + "query": "3306", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.48", + "time": "09/11/2024 10:23 PM", + "tokens": 66, + }, + "parentMessageId": "question-684b5396-4e91-4043-88e9-aabe48b21acc", + "siblingIndex": 1, + "workflow_run_id": null, + }, + ], + "content": "3306", + "id": "question-684b5396-4e91-4043-88e9-aabe48b21acc", + "isAnswer": false, + "message_files": [], + "parentMessageId": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + }, + ], + "content": "Sure! My number is 54. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "input": { + "inputs": {}, + "query": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.52", + "time": "09/11/2024 09:50 PM", + "tokens": 46, + }, + "parentMessageId": "question-cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "siblingIndex": 1, + "workflow_run_id": null, + }, + ], + "content": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + "id": "question-cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "isAnswer": false, + "message_files": [], + }, + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105809, + "files": [], + "id": "1019cd79-d141-4f9f-880a-fc1441cfd802", + "message_id": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "observation": "", + "position": 1, + "thought": "Sure! My number is 54. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105822, + "files": [], + "id": "0773bec7-b992-4a53-92b2-20ebaeae8798", + "message_id": "324bce32-c98c-435d-a66b-bac974ebb5ed", + "observation": "", + "position": 1, + "thought": "My number is 4729. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "My number is 4729. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "324bce32-c98c-435d-a66b-bac974ebb5ed", + "input": { + "inputs": {}, + "query": "3306", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4729. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.30", + "time": "09/11/2024 09:50 PM", + "tokens": 66, + }, + "parentMessageId": "question-324bce32-c98c-435d-a66b-bac974ebb5ed", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "3306", + "id": "question-324bce32-c98c-435d-a66b-bac974ebb5ed", + "isAnswer": false, + "message_files": [], + "parentMessageId": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + }, + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726107812, + "files": [], + "id": "5ca650f3-982c-4399-8b95-9ea241c76707", + "message_id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "observation": "", + "position": 1, + "thought": "My number is 4821. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726111024, + "files": [], + "id": "095cacab-afad-4387-a41d-1662578b8b13", + "message_id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "observation": "", + "position": 1, + "thought": "My number is 1456. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "My number is 1456. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "input": { + "inputs": {}, + "query": "1003", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "1003", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 1456. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.38", + "time": "09/11/2024 11:17 PM", + "tokens": 86, + }, + "parentMessageId": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "1003", + "id": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "isAnswer": false, + "message_files": [], + "parentMessageId": "684b5396-4e91-4043-88e9-aabe48b21acc", + }, + ], + "content": "My number is 4821. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "input": { + "inputs": {}, + "query": "3306", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.48", + "time": "09/11/2024 10:23 PM", + "tokens": 66, + }, + "parentMessageId": "question-684b5396-4e91-4043-88e9-aabe48b21acc", + "siblingIndex": 1, + "workflow_run_id": null, + }, + ], + "content": "3306", + "id": "question-684b5396-4e91-4043-88e9-aabe48b21acc", + "isAnswer": false, + "message_files": [], + "parentMessageId": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + }, + ], + "content": "Sure! My number is 54. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "input": { + "inputs": {}, + "query": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.52", + "time": "09/11/2024 09:50 PM", + "tokens": 46, + }, + "nextSibling": undefined, + "parentMessageId": "question-cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "prevSibling": "ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "siblingCount": 2, + "siblingIndex": 1, + "workflow_run_id": null, + }, + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726107812, + "files": [], + "id": "5ca650f3-982c-4399-8b95-9ea241c76707", + "message_id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "observation": "", + "position": 1, + "thought": "My number is 4821. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726111024, + "files": [], + "id": "095cacab-afad-4387-a41d-1662578b8b13", + "message_id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "observation": "", + "position": 1, + "thought": "My number is 1456. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "My number is 1456. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "input": { + "inputs": {}, + "query": "1003", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "1003", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 1456. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.38", + "time": "09/11/2024 11:17 PM", + "tokens": 86, + }, + "parentMessageId": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "1003", + "id": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "isAnswer": false, + "message_files": [], + "parentMessageId": "684b5396-4e91-4043-88e9-aabe48b21acc", + }, + ], + "content": "My number is 4821. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "input": { + "inputs": {}, + "query": "3306", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.48", + "time": "09/11/2024 10:23 PM", + "tokens": 66, + }, + "parentMessageId": "question-684b5396-4e91-4043-88e9-aabe48b21acc", + "siblingIndex": 1, + "workflow_run_id": null, + }, + ], + "content": "3306", + "id": "question-684b5396-4e91-4043-88e9-aabe48b21acc", + "isAnswer": false, + "message_files": [], + "parentMessageId": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + }, + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726107812, + "files": [], + "id": "5ca650f3-982c-4399-8b95-9ea241c76707", + "message_id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "observation": "", + "position": 1, + "thought": "My number is 4821. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726111024, + "files": [], + "id": "095cacab-afad-4387-a41d-1662578b8b13", + "message_id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "observation": "", + "position": 1, + "thought": "My number is 1456. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "My number is 1456. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "input": { + "inputs": {}, + "query": "1003", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "1003", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 1456. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.38", + "time": "09/11/2024 11:17 PM", + "tokens": 86, + }, + "parentMessageId": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "1003", + "id": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "isAnswer": false, + "message_files": [], + "parentMessageId": "684b5396-4e91-4043-88e9-aabe48b21acc", + }, + ], + "content": "My number is 4821. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "input": { + "inputs": {}, + "query": "3306", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.48", + "time": "09/11/2024 10:23 PM", + "tokens": 66, + }, + "nextSibling": undefined, + "parentMessageId": "question-684b5396-4e91-4043-88e9-aabe48b21acc", + "prevSibling": "324bce32-c98c-435d-a66b-bac974ebb5ed", + "siblingCount": 2, + "siblingIndex": 1, + "workflow_run_id": null, + }, + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726111024, + "files": [], + "id": "095cacab-afad-4387-a41d-1662578b8b13", + "message_id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "observation": "", + "position": 1, + "thought": "My number is 1456. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "My number is 1456. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "input": { + "inputs": {}, + "query": "1003", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "1003", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 1456. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.38", + "time": "09/11/2024 11:17 PM", + "tokens": 86, + }, + "parentMessageId": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "1003", + "id": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "isAnswer": false, + "message_files": [], + "parentMessageId": "684b5396-4e91-4043-88e9-aabe48b21acc", + }, + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726111024, + "files": [], + "id": "095cacab-afad-4387-a41d-1662578b8b13", + "message_id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "observation": "", + "position": 1, + "thought": "My number is 1456. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "My number is 1456. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "input": { + "inputs": {}, + "query": "1003", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "1003", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 1456. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.38", + "time": "09/11/2024 11:17 PM", + "tokens": 86, + }, + "nextSibling": undefined, + "parentMessageId": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "prevSibling": undefined, + "siblingCount": 1, + "siblingIndex": 0, + "workflow_run_id": null, + }, +] +`; + +exports[`build chat item tree and get thread messages should work with real world messages 1`] = ` +[ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105791, + "files": [], + "id": "f9d7ff7c-3a3b-4d9a-a289-657817f4caff", + "message_id": "ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "observation": "", + "position": 1, + "thought": "Sure, I'll play! My number is 57. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105795, + "files": [], + "id": "f61a3fce-37ac-4f9d-9935-95f97e598dfe", + "message_id": "73bbad14-d915-499d-87bf-0df14d40779d", + "observation": "", + "position": 1, + "thought": "I choose 83. What's your next number?", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105799, + "files": [], + "id": "9730d587-9268-4683-9dd9-91a1cab9510b", + "message_id": "4c5d0841-1206-463e-95d8-71f812877658", + "observation": "", + "position": 1, + "thought": "I'll go with 112. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "I'll go with 112. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "4c5d0841-1206-463e-95d8-71f812877658", + "input": { + "inputs": {}, + "query": "99", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "58", + }, + { + "files": [], + "role": "assistant", + "text": "I choose 83. What's your next number?", + }, + { + "files": [], + "role": "user", + "text": "99", + }, + { + "files": [], + "role": "assistant", + "text": "I'll go with 112. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.49", + "time": "09/11/2024 09:50 PM", + "tokens": 86, + }, + "parentMessageId": "question-4c5d0841-1206-463e-95d8-71f812877658", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "99", + "id": "question-4c5d0841-1206-463e-95d8-71f812877658", + "isAnswer": false, + "message_files": [], + "parentMessageId": "73bbad14-d915-499d-87bf-0df14d40779d", + }, + ], + "content": "I choose 83. What's your next number?", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "73bbad14-d915-499d-87bf-0df14d40779d", + "input": { + "inputs": {}, + "query": "58", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "58", + }, + { + "files": [], + "role": "assistant", + "text": "I choose 83. What's your next number?", + }, + ], + "message_files": [], + "more": { + "latency": "1.33", + "time": "09/11/2024 09:49 PM", + "tokens": 68, + }, + "parentMessageId": "question-73bbad14-d915-499d-87bf-0df14d40779d", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "58", + "id": "question-73bbad14-d915-499d-87bf-0df14d40779d", + "isAnswer": false, + "message_files": [], + "parentMessageId": "ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + }, + ], + "content": "Sure, I'll play! My number is 57. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "input": { + "inputs": {}, + "query": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.56", + "time": "09/11/2024 09:49 PM", + "tokens": 49, + }, + "parentMessageId": "question-ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "Let's play a game, I say a number , and you response me with another bigger, yet random-looking number. I'll start first, 38", + "id": "question-ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "isAnswer": false, + "message_files": [], + }, + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105809, + "files": [], + "id": "1019cd79-d141-4f9f-880a-fc1441cfd802", + "message_id": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "observation": "", + "position": 1, + "thought": "Sure! My number is 54. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105822, + "files": [], + "id": "0773bec7-b992-4a53-92b2-20ebaeae8798", + "message_id": "324bce32-c98c-435d-a66b-bac974ebb5ed", + "observation": "", + "position": 1, + "thought": "My number is 4729. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "My number is 4729. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "324bce32-c98c-435d-a66b-bac974ebb5ed", + "input": { + "inputs": {}, + "query": "3306", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4729. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.30", + "time": "09/11/2024 09:50 PM", + "tokens": 66, + }, + "parentMessageId": "question-324bce32-c98c-435d-a66b-bac974ebb5ed", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "3306", + "id": "question-324bce32-c98c-435d-a66b-bac974ebb5ed", + "isAnswer": false, + "message_files": [], + "parentMessageId": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + }, + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726107812, + "files": [], + "id": "5ca650f3-982c-4399-8b95-9ea241c76707", + "message_id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "observation": "", + "position": 1, + "thought": "My number is 4821. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726111024, + "files": [], + "id": "095cacab-afad-4387-a41d-1662578b8b13", + "message_id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "observation": "", + "position": 1, + "thought": "My number is 1456. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "My number is 1456. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "input": { + "inputs": {}, + "query": "1003", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "1003", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 1456. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.38", + "time": "09/11/2024 11:17 PM", + "tokens": 86, + }, + "parentMessageId": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "1003", + "id": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "isAnswer": false, + "message_files": [], + "parentMessageId": "684b5396-4e91-4043-88e9-aabe48b21acc", + }, + ], + "content": "My number is 4821. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "input": { + "inputs": {}, + "query": "3306", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.48", + "time": "09/11/2024 10:23 PM", + "tokens": 66, + }, + "parentMessageId": "question-684b5396-4e91-4043-88e9-aabe48b21acc", + "siblingIndex": 1, + "workflow_run_id": null, + }, + ], + "content": "3306", + "id": "question-684b5396-4e91-4043-88e9-aabe48b21acc", + "isAnswer": false, + "message_files": [], + "parentMessageId": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + }, + ], + "content": "Sure! My number is 54. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "input": { + "inputs": {}, + "query": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.52", + "time": "09/11/2024 09:50 PM", + "tokens": 46, + }, + "parentMessageId": "question-cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "siblingIndex": 1, + "workflow_run_id": null, + }, + ], + "content": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + "id": "question-cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "isAnswer": false, + "message_files": [], + }, +] +`; diff --git a/web/app/components/base/chat/__tests__/branchedTestMessages.json b/web/app/components/base/chat/__tests__/branchedTestMessages.json new file mode 100644 index 0000000000..30e0a82cb5 --- /dev/null +++ b/web/app/components/base/chat/__tests__/branchedTestMessages.json @@ -0,0 +1,42 @@ +[ + { + "id": "question-1", + "isAnswer": false, + "parentMessageId": null + }, + { + "id": "1", + "isAnswer": true, + "parentMessageId": "question-1" + }, + { + "id": "question-2", + "isAnswer": false, + "parentMessageId": "1" + }, + { + "id": "2", + "isAnswer": true, + "parentMessageId": "question-2" + }, + { + "id": "question-3", + "isAnswer": false, + "parentMessageId": "2" + }, + { + "id": "3", + "isAnswer": true, + "parentMessageId": "question-3" + }, + { + "id": "question-4", + "isAnswer": false, + "parentMessageId": "1" + }, + { + "id": "4", + "isAnswer": true, + "parentMessageId": "question-4" + } +] diff --git a/web/app/components/base/chat/__tests__/legacyTestMessages.json b/web/app/components/base/chat/__tests__/legacyTestMessages.json new file mode 100644 index 0000000000..2dab58985a --- /dev/null +++ b/web/app/components/base/chat/__tests__/legacyTestMessages.json @@ -0,0 +1,42 @@ +[ + { + "id": "question-1", + "isAnswer": false, + "parentMessageId": "00000000-0000-0000-0000-000000000000" + }, + { + "id": "1", + "isAnswer": true, + "parentMessageId": "question-1" + }, + { + "id": "question-2", + "isAnswer": false, + "parentMessageId": "00000000-0000-0000-0000-000000000000" + }, + { + "id": "2", + "isAnswer": true, + "parentMessageId": "question-2" + }, + { + "id": "question-3", + "isAnswer": false, + "parentMessageId": "00000000-0000-0000-0000-000000000000" + }, + { + "id": "3", + "isAnswer": true, + "parentMessageId": "question-3" + }, + { + "id": "question-4", + "isAnswer": false, + "parentMessageId": "00000000-0000-0000-0000-000000000000" + }, + { + "id": "4", + "isAnswer": true, + "parentMessageId": "question-4" + } +] diff --git a/web/app/components/base/chat/__tests__/mixedTestMessages.json b/web/app/components/base/chat/__tests__/mixedTestMessages.json new file mode 100644 index 0000000000..14789d9518 --- /dev/null +++ b/web/app/components/base/chat/__tests__/mixedTestMessages.json @@ -0,0 +1,42 @@ +[ + { + "id": "question-1", + "isAnswer": false, + "parentMessageId": "00000000-0000-0000-0000-000000000000" + }, + { + "id": "1", + "isAnswer": true, + "parentMessageId": "question-1" + }, + { + "id": "question-2", + "isAnswer": false, + "parentMessageId": "00000000-0000-0000-0000-000000000000" + }, + { + "id": "2", + "isAnswer": true, + "parentMessageId": "question-2" + }, + { + "id": "question-3", + "isAnswer": false, + "parentMessageId": "2" + }, + { + "id": "3", + "isAnswer": true, + "parentMessageId": "question-3" + }, + { + "id": "question-4", + "isAnswer": false, + "parentMessageId": "1" + }, + { + "id": "4", + "isAnswer": true, + "parentMessageId": "question-4" + } +] diff --git a/web/app/components/base/chat/__tests__/multiRootNodesMessages.json b/web/app/components/base/chat/__tests__/multiRootNodesMessages.json new file mode 100644 index 0000000000..782ccb7f94 --- /dev/null +++ b/web/app/components/base/chat/__tests__/multiRootNodesMessages.json @@ -0,0 +1,52 @@ +[ + { + "id": "question-1", + "isAnswer": false, + "parentMessageId": null + }, + { + "id": "1", + "isAnswer": true, + "parentMessageId": "question-1" + }, + { + "id": "question-2", + "isAnswer": false, + "parentMessageId": "1" + }, + { + "id": "2", + "isAnswer": true, + "parentMessageId": "question-2" + }, + { + "id": "question-3", + "isAnswer": false, + "parentMessageId": "2" + }, + { + "id": "3", + "isAnswer": true, + "parentMessageId": "question-3" + }, + { + "id": "question-4", + "isAnswer": false, + "parentMessageId": "1" + }, + { + "id": "4", + "isAnswer": true, + "parentMessageId": "question-4" + }, + { + "id": "question-5", + "isAnswer": false, + "parentMessageId": null + }, + { + "id": "5", + "isAnswer": true, + "parentMessageId": "question-5" + } +] diff --git a/web/app/components/base/chat/__tests__/multiRootNodesWithLegacyTestMessages.json b/web/app/components/base/chat/__tests__/multiRootNodesWithLegacyTestMessages.json new file mode 100644 index 0000000000..5eadc726e5 --- /dev/null +++ b/web/app/components/base/chat/__tests__/multiRootNodesWithLegacyTestMessages.json @@ -0,0 +1,52 @@ +[ + { + "id": "question-1", + "isAnswer": false, + "parentMessageId": "00000000-0000-0000-0000-000000000000" + }, + { + "id": "1", + "isAnswer": true, + "parentMessageId": "question-1" + }, + { + "id": "question-2", + "isAnswer": false, + "parentMessageId": "00000000-0000-0000-0000-000000000000" + }, + { + "id": "2", + "isAnswer": true, + "parentMessageId": "question-2" + }, + { + "id": "question-3", + "isAnswer": false, + "parentMessageId": "00000000-0000-0000-0000-000000000000" + }, + { + "id": "3", + "isAnswer": true, + "parentMessageId": "question-3" + }, + { + "id": "question-4", + "isAnswer": false, + "parentMessageId": "1" + }, + { + "id": "4", + "isAnswer": true, + "parentMessageId": "question-4" + }, + { + "id": "question-5", + "isAnswer": false, + "parentMessageId": null + }, + { + "id": "5", + "isAnswer": true, + "parentMessageId": "question-5" + } +] diff --git a/web/app/components/base/chat/__tests__/realWorldMessages.json b/web/app/components/base/chat/__tests__/realWorldMessages.json new file mode 100644 index 0000000000..858052c77f --- /dev/null +++ b/web/app/components/base/chat/__tests__/realWorldMessages.json @@ -0,0 +1,441 @@ +[ + { + "id": "question-ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "content": "Let's play a game, I say a number , and you response me with another bigger, yet random-looking number. I'll start first, 38", + "isAnswer": false, + "message_files": [] + }, + { + "id": "ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "content": "Sure, I'll play! My number is 57. Your turn!", + "agent_thoughts": [ + { + "id": "f9d7ff7c-3a3b-4d9a-a289-657817f4caff", + "chain_id": null, + "message_id": "ff4c2b43-48a5-47ad-9dc5-08b34ddba61b", + "position": 1, + "thought": "Sure, I'll play! My number is 57. Your turn!", + "tool": "", + "tool_labels": {}, + "tool_input": "", + "created_at": 1726105791, + "observation": "", + "files": [] + } + ], + "feedbackDisabled": false, + "isAnswer": true, + "message_files": [], + "log": [ + { + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + "files": [] + }, + { + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + "files": [] + } + ], + "workflow_run_id": null, + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "input": { + "inputs": {}, + "query": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38" + }, + "more": { + "time": "09/11/2024 09:49 PM", + "tokens": 49, + "latency": "1.56" + }, + "parentMessageId": "question-ff4c2b43-48a5-47ad-9dc5-08b34ddba61b" + }, + { + "id": "question-73bbad14-d915-499d-87bf-0df14d40779d", + "content": "58", + "isAnswer": false, + "message_files": [], + "parentMessageId": "ff4c2b43-48a5-47ad-9dc5-08b34ddba61b" + }, + { + "id": "73bbad14-d915-499d-87bf-0df14d40779d", + "content": "I choose 83. What's your next number?", + "agent_thoughts": [ + { + "id": "f61a3fce-37ac-4f9d-9935-95f97e598dfe", + "chain_id": null, + "message_id": "73bbad14-d915-499d-87bf-0df14d40779d", + "position": 1, + "thought": "I choose 83. What's your next number?", + "tool": "", + "tool_labels": {}, + "tool_input": "", + "created_at": 1726105795, + "observation": "", + "files": [] + } + ], + "feedbackDisabled": false, + "isAnswer": true, + "message_files": [], + "log": [ + { + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + "files": [] + }, + { + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + "files": [] + }, + { + "role": "user", + "text": "58", + "files": [] + }, + { + "role": "assistant", + "text": "I choose 83. What's your next number?", + "files": [] + } + ], + "workflow_run_id": null, + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "input": { + "inputs": {}, + "query": "58" + }, + "more": { + "time": "09/11/2024 09:49 PM", + "tokens": 68, + "latency": "1.33" + }, + "parentMessageId": "question-73bbad14-d915-499d-87bf-0df14d40779d" + }, + { + "id": "question-4c5d0841-1206-463e-95d8-71f812877658", + "content": "99", + "isAnswer": false, + "message_files": [], + "parentMessageId": "73bbad14-d915-499d-87bf-0df14d40779d" + }, + { + "id": "4c5d0841-1206-463e-95d8-71f812877658", + "content": "I'll go with 112. Your turn!", + "agent_thoughts": [ + { + "id": "9730d587-9268-4683-9dd9-91a1cab9510b", + "chain_id": null, + "message_id": "4c5d0841-1206-463e-95d8-71f812877658", + "position": 1, + "thought": "I'll go with 112. Your turn!", + "tool": "", + "tool_labels": {}, + "tool_input": "", + "created_at": 1726105799, + "observation": "", + "files": [] + } + ], + "feedbackDisabled": false, + "isAnswer": true, + "message_files": [], + "log": [ + { + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + "files": [] + }, + { + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + "files": [] + }, + { + "role": "user", + "text": "58", + "files": [] + }, + { + "role": "assistant", + "text": "I choose 83. What's your next number?", + "files": [] + }, + { + "role": "user", + "text": "99", + "files": [] + }, + { + "role": "assistant", + "text": "I'll go with 112. Your turn!", + "files": [] + } + ], + "workflow_run_id": null, + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "input": { + "inputs": {}, + "query": "99" + }, + "more": { + "time": "09/11/2024 09:50 PM", + "tokens": 86, + "latency": "1.49" + }, + "parentMessageId": "question-4c5d0841-1206-463e-95d8-71f812877658" + }, + { + "id": "question-cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "content": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + "isAnswer": false, + "message_files": [] + }, + { + "id": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "content": "Sure! My number is 54. Your turn!", + "agent_thoughts": [ + { + "id": "1019cd79-d141-4f9f-880a-fc1441cfd802", + "chain_id": null, + "message_id": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "position": 1, + "thought": "Sure! My number is 54. Your turn!", + "tool": "", + "tool_labels": {}, + "tool_input": "", + "created_at": 1726105809, + "observation": "", + "files": [] + } + ], + "feedbackDisabled": false, + "isAnswer": true, + "message_files": [], + "log": [ + { + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + "files": [] + }, + { + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + "files": [] + } + ], + "workflow_run_id": null, + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "input": { + "inputs": {}, + "query": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38" + }, + "more": { + "time": "09/11/2024 09:50 PM", + "tokens": 46, + "latency": "1.52" + }, + "parentMessageId": "question-cd5affb0-7bc2-4a6f-be7e-25e74595c9dd" + }, + { + "id": "question-324bce32-c98c-435d-a66b-bac974ebb5ed", + "content": "3306", + "isAnswer": false, + "message_files": [], + "parentMessageId": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd" + }, + { + "id": "324bce32-c98c-435d-a66b-bac974ebb5ed", + "content": "My number is 4729. Your turn!", + "agent_thoughts": [ + { + "id": "0773bec7-b992-4a53-92b2-20ebaeae8798", + "chain_id": null, + "message_id": "324bce32-c98c-435d-a66b-bac974ebb5ed", + "position": 1, + "thought": "My number is 4729. Your turn!", + "tool": "", + "tool_labels": {}, + "tool_input": "", + "created_at": 1726105822, + "observation": "", + "files": [] + } + ], + "feedbackDisabled": false, + "isAnswer": true, + "message_files": [], + "log": [ + { + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + "files": [] + }, + { + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + "files": [] + }, + { + "role": "user", + "text": "3306", + "files": [] + }, + { + "role": "assistant", + "text": "My number is 4729. Your turn!", + "files": [] + } + ], + "workflow_run_id": null, + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "input": { + "inputs": {}, + "query": "3306" + }, + "more": { + "time": "09/11/2024 09:50 PM", + "tokens": 66, + "latency": "1.30" + }, + "parentMessageId": "question-324bce32-c98c-435d-a66b-bac974ebb5ed" + }, + { + "id": "question-684b5396-4e91-4043-88e9-aabe48b21acc", + "content": "3306", + "isAnswer": false, + "message_files": [], + "parentMessageId": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd" + }, + { + "id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "content": "My number is 4821. Your turn!", + "agent_thoughts": [ + { + "id": "5ca650f3-982c-4399-8b95-9ea241c76707", + "chain_id": null, + "message_id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "position": 1, + "thought": "My number is 4821. Your turn!", + "tool": "", + "tool_labels": {}, + "tool_input": "", + "created_at": 1726107812, + "observation": "", + "files": [] + } + ], + "feedbackDisabled": false, + "isAnswer": true, + "message_files": [], + "log": [ + { + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + "files": [] + }, + { + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + "files": [] + }, + { + "role": "user", + "text": "3306", + "files": [] + }, + { + "role": "assistant", + "text": "My number is 4821. Your turn!", + "files": [] + } + ], + "workflow_run_id": null, + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "input": { + "inputs": {}, + "query": "3306" + }, + "more": { + "time": "09/11/2024 10:23 PM", + "tokens": 66, + "latency": "1.48" + }, + "parentMessageId": "question-684b5396-4e91-4043-88e9-aabe48b21acc" + }, + { + "id": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "content": "1003", + "isAnswer": false, + "message_files": [], + "parentMessageId": "684b5396-4e91-4043-88e9-aabe48b21acc" + }, + { + "id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "content": "My number is 1456. Your turn!", + "agent_thoughts": [ + { + "id": "095cacab-afad-4387-a41d-1662578b8b13", + "chain_id": null, + "message_id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "position": 1, + "thought": "My number is 1456. Your turn!", + "tool": "", + "tool_labels": {}, + "tool_input": "", + "created_at": 1726111024, + "observation": "", + "files": [] + } + ], + "feedbackDisabled": false, + "isAnswer": true, + "message_files": [], + "log": [ + { + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + "files": [] + }, + { + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + "files": [] + }, + { + "role": "user", + "text": "3306", + "files": [] + }, + { + "role": "assistant", + "text": "My number is 4821. Your turn!", + "files": [] + }, + { + "role": "user", + "text": "1003", + "files": [] + }, + { + "role": "assistant", + "text": "My number is 1456. Your turn!", + "files": [] + } + ], + "workflow_run_id": null, + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "input": { + "inputs": {}, + "query": "1003" + }, + "more": { + "time": "09/11/2024 11:17 PM", + "tokens": 86, + "latency": "1.38" + }, + "parentMessageId": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c" + } +] diff --git a/web/app/components/base/chat/__tests__/utils.spec.ts b/web/app/components/base/chat/__tests__/utils.spec.ts new file mode 100644 index 0000000000..c602ac8a99 --- /dev/null +++ b/web/app/components/base/chat/__tests__/utils.spec.ts @@ -0,0 +1,258 @@ +import { get } from 'lodash' +import { buildChatItemTree, getThreadMessages } from '../utils' +import type { ChatItemInTree } from '../types' +import branchedTestMessages from './branchedTestMessages.json' +import legacyTestMessages from './legacyTestMessages.json' +import mixedTestMessages from './mixedTestMessages.json' +import multiRootNodesMessages from './multiRootNodesMessages.json' +import multiRootNodesWithLegacyTestMessages from './multiRootNodesWithLegacyTestMessages.json' +import realWorldMessages from './realWorldMessages.json' + +function visitNode(tree: ChatItemInTree | ChatItemInTree[], path: string): ChatItemInTree { + return get(tree, path) +} + +describe('build chat item tree and get thread messages', () => { + const tree1 = buildChatItemTree(branchedTestMessages as ChatItemInTree[]) + + it('should build chat item tree1', () => { + const a1 = visitNode(tree1, '0.children.0') + expect(a1.id).toBe('1') + expect(a1.children).toHaveLength(2) + + const a2 = visitNode(a1, 'children.0.children.0') + expect(a2.id).toBe('2') + expect(a2.siblingIndex).toBe(0) + + const a3 = visitNode(a2, 'children.0.children.0') + expect(a3.id).toBe('3') + + const a4 = visitNode(a1, 'children.1.children.0') + expect(a4.id).toBe('4') + expect(a4.siblingIndex).toBe(1) + }) + + it('should get thread messages from tree1, using the last message as the target', () => { + const threadChatItems1_1 = getThreadMessages(tree1) + expect(threadChatItems1_1).toHaveLength(4) + + const q1 = visitNode(threadChatItems1_1, '0') + const a1 = visitNode(threadChatItems1_1, '1') + const q4 = visitNode(threadChatItems1_1, '2') + const a4 = visitNode(threadChatItems1_1, '3') + + expect(q1.id).toBe('question-1') + expect(a1.id).toBe('1') + expect(q4.id).toBe('question-4') + expect(a4.id).toBe('4') + + expect(a4.siblingCount).toBe(2) + expect(a4.siblingIndex).toBe(1) + }) + + it('should get thread messages from tree1, using the message with id 3 as the target', () => { + const threadChatItems1_2 = getThreadMessages(tree1, '3') + expect(threadChatItems1_2).toHaveLength(6) + + const q1 = visitNode(threadChatItems1_2, '0') + const a1 = visitNode(threadChatItems1_2, '1') + const q2 = visitNode(threadChatItems1_2, '2') + const a2 = visitNode(threadChatItems1_2, '3') + const q3 = visitNode(threadChatItems1_2, '4') + const a3 = visitNode(threadChatItems1_2, '5') + + expect(q1.id).toBe('question-1') + expect(a1.id).toBe('1') + expect(q2.id).toBe('question-2') + expect(a2.id).toBe('2') + expect(q3.id).toBe('question-3') + expect(a3.id).toBe('3') + + expect(a2.siblingCount).toBe(2) + expect(a2.siblingIndex).toBe(0) + }) + + const tree2 = buildChatItemTree(legacyTestMessages as ChatItemInTree[]) + it('should work with legacy chat items', () => { + expect(tree2).toHaveLength(1) + const q1 = visitNode(tree2, '0') + const a1 = visitNode(q1, 'children.0') + const q2 = visitNode(a1, 'children.0') + const a2 = visitNode(q2, 'children.0') + const q3 = visitNode(a2, 'children.0') + const a3 = visitNode(q3, 'children.0') + const q4 = visitNode(a3, 'children.0') + const a4 = visitNode(q4, 'children.0') + + expect(q1.id).toBe('question-1') + expect(a1.id).toBe('1') + expect(q2.id).toBe('question-2') + expect(a2.id).toBe('2') + expect(q3.id).toBe('question-3') + expect(a3.id).toBe('3') + expect(q4.id).toBe('question-4') + expect(a4.id).toBe('4') + }) + + it('should get thread messages from tree2, using the last message as the target', () => { + const threadMessages2 = getThreadMessages(tree2) + expect(threadMessages2).toHaveLength(8) + + const q1 = visitNode(threadMessages2, '0') + const a1 = visitNode(threadMessages2, '1') + const q2 = visitNode(threadMessages2, '2') + const a2 = visitNode(threadMessages2, '3') + const q3 = visitNode(threadMessages2, '4') + const a3 = visitNode(threadMessages2, '5') + const q4 = visitNode(threadMessages2, '6') + const a4 = visitNode(threadMessages2, '7') + + expect(q1.id).toBe('question-1') + expect(a1.id).toBe('1') + expect(q2.id).toBe('question-2') + expect(a2.id).toBe('2') + expect(q3.id).toBe('question-3') + expect(a3.id).toBe('3') + expect(q4.id).toBe('question-4') + expect(a4.id).toBe('4') + + expect(a1.siblingCount).toBe(1) + expect(a1.siblingIndex).toBe(0) + expect(a2.siblingCount).toBe(1) + expect(a2.siblingIndex).toBe(0) + expect(a3.siblingCount).toBe(1) + expect(a3.siblingIndex).toBe(0) + expect(a4.siblingCount).toBe(1) + expect(a4.siblingIndex).toBe(0) + }) + + const tree3 = buildChatItemTree(mixedTestMessages as ChatItemInTree[]) + it('should build mixed chat items tree', () => { + expect(tree3).toHaveLength(1) + + const a1 = visitNode(tree3, '0.children.0') + expect(a1.id).toBe('1') + expect(a1.children).toHaveLength(2) + + const a2 = visitNode(a1, 'children.0.children.0') + expect(a2.id).toBe('2') + expect(a2.siblingIndex).toBe(0) + + const a3 = visitNode(a2, 'children.0.children.0') + expect(a3.id).toBe('3') + + const a4 = visitNode(a1, 'children.1.children.0') + expect(a4.id).toBe('4') + expect(a4.siblingIndex).toBe(1) + }) + + it('should get thread messages from tree3, using the last message as the target', () => { + const threadMessages3_1 = getThreadMessages(tree3) + expect(threadMessages3_1).toHaveLength(4) + + const q1 = visitNode(threadMessages3_1, '0') + const a1 = visitNode(threadMessages3_1, '1') + const q4 = visitNode(threadMessages3_1, '2') + const a4 = visitNode(threadMessages3_1, '3') + + expect(q1.id).toBe('question-1') + expect(a1.id).toBe('1') + expect(q4.id).toBe('question-4') + expect(a4.id).toBe('4') + + expect(a4.siblingCount).toBe(2) + expect(a4.siblingIndex).toBe(1) + }) + + it('should get thread messages from tree3, using the message with id 3 as the target', () => { + const threadMessages3_2 = getThreadMessages(tree3, '3') + expect(threadMessages3_2).toHaveLength(6) + + const q1 = visitNode(threadMessages3_2, '0') + const a1 = visitNode(threadMessages3_2, '1') + const q2 = visitNode(threadMessages3_2, '2') + const a2 = visitNode(threadMessages3_2, '3') + const q3 = visitNode(threadMessages3_2, '4') + const a3 = visitNode(threadMessages3_2, '5') + + expect(q1.id).toBe('question-1') + expect(a1.id).toBe('1') + expect(q2.id).toBe('question-2') + expect(a2.id).toBe('2') + expect(q3.id).toBe('question-3') + expect(a3.id).toBe('3') + + expect(a2.siblingCount).toBe(2) + expect(a2.siblingIndex).toBe(0) + }) + + const tree4 = buildChatItemTree(multiRootNodesMessages as ChatItemInTree[]) + it('should build multi root nodes chat items tree', () => { + expect(tree4).toHaveLength(2) + + const a5 = visitNode(tree4, '1.children.0') + expect(a5.id).toBe('5') + expect(a5.siblingIndex).toBe(1) + }) + + it('should get thread messages from tree4, using the last message as the target', () => { + const threadMessages4 = getThreadMessages(tree4) + expect(threadMessages4).toHaveLength(2) + + const a1 = visitNode(threadMessages4, '0.children.0') + expect(a1.id).toBe('5') + }) + + it('should get thread messages from tree4, using the message with id 2 as the target', () => { + const threadMessages4_1 = getThreadMessages(tree4, '2') + expect(threadMessages4_1).toHaveLength(6) + const a1 = visitNode(threadMessages4_1, '1') + expect(a1.id).toBe('1') + const a2 = visitNode(threadMessages4_1, '3') + expect(a2.id).toBe('2') + const a3 = visitNode(threadMessages4_1, '5') + expect(a3.id).toBe('3') + }) + + const tree5 = buildChatItemTree(multiRootNodesWithLegacyTestMessages as ChatItemInTree[]) + it('should work with multi root nodes chat items with legacy chat items', () => { + expect(tree5).toHaveLength(2) + + const q5 = visitNode(tree5, '1') + expect(q5.id).toBe('question-5') + expect(q5.parentMessageId).toBe(null) + + const a5 = visitNode(q5, 'children.0') + expect(a5.id).toBe('5') + expect(a5.children).toHaveLength(0) + }) + + it('should get thread messages from tree5, using the last message as the target', () => { + const threadMessages5 = getThreadMessages(tree5) + expect(threadMessages5).toHaveLength(2) + + const q5 = visitNode(threadMessages5, '0') + const a5 = visitNode(threadMessages5, '1') + + expect(q5.id).toBe('question-5') + expect(a5.id).toBe('5') + + expect(a5.siblingCount).toBe(2) + expect(a5.siblingIndex).toBe(1) + }) + + const tree6 = buildChatItemTree(realWorldMessages as ChatItemInTree[]) + it('should work with real world messages', () => { + expect(tree6).toMatchSnapshot() + }) + + it ('should get thread messages from tree6, using the last message as target', () => { + const threadMessages6_1 = getThreadMessages(tree6) + expect(threadMessages6_1).toMatchSnapshot() + }) + + it ('should get thread messages from tree6, using specified message as target', () => { + const threadMessages6_2 = getThreadMessages(tree6, 'ff4c2b43-48a5-47ad-9dc5-08b34ddba61b') + expect(threadMessages6_2).toMatchSnapshot() + }) +}) diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index e632b14969..724ef78e75 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -40,6 +40,10 @@ const ChatWrapper = () => { return { ...config, + file_upload: { + ...(config as any).file_upload, + fileUploadConfig: (config as any).system_parameters, + }, supportFeedback: true, opening_statement: currentConversationId ? currentConversationItem?.introduction : (config as any).opening_statement, } as ChatConfig diff --git a/web/app/components/base/chat/chat-with-history/config-panel/form.tsx b/web/app/components/base/chat/chat-with-history/config-panel/form.tsx index 298d8ccd7f..1292edabd2 100644 --- a/web/app/components/base/chat/chat-with-history/config-panel/form.tsx +++ b/web/app/components/base/chat/chat-with-history/config-panel/form.tsx @@ -9,6 +9,7 @@ import { FileUploaderInAttachmentWrapper } from '@/app/components/base/file-uplo const Form = () => { const { t } = useTranslation() const { + appParams, inputsForms, newConversationInputs, newConversationInputsRef, @@ -61,6 +62,7 @@ const Form = () => { allowed_file_extensions: form.allowed_file_extensions, allowed_file_upload_methods: form.allowed_file_upload_methods, number_limits: 1, + fileUploadConfig: (appParams as any).system_parameters, }} /> ) @@ -75,6 +77,7 @@ const Form = () => { allowed_file_extensions: form.allowed_file_extensions, allowed_file_upload_methods: form.allowed_file_upload_methods, number_limits: form.max_length, + fileUploadConfig: (appParams as any).system_parameters, }} /> ) diff --git a/web/app/components/base/chat/chat/answer/index.tsx b/web/app/components/base/chat/chat/answer/index.tsx index 50f51f521f..1ff390bd58 100644 --- a/web/app/components/base/chat/chat/answer/index.tsx +++ b/web/app/components/base/chat/chat/answer/index.tsx @@ -19,6 +19,7 @@ import Citation from '@/app/components/base/chat/chat/citation' import { EditTitle } from '@/app/components/app/annotation/edit-annotation-modal/edit-item' import type { AppData } from '@/models/share' import AnswerIcon from '@/app/components/base/answer-icon' +import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows' import cn from '@/utils/classnames' import { FileList } from '@/app/components/base/file-uploader' @@ -34,6 +35,7 @@ type AnswerProps = { hideProcessDetail?: boolean appData?: AppData noChatInput?: boolean + switchSibling?: (siblingMessageId: string) => void } const Answer: FC = ({ item, @@ -47,6 +49,7 @@ const Answer: FC = ({ hideProcessDetail, appData, noChatInput, + switchSibling, }) => { const { t } = useTranslation() const { @@ -203,6 +206,23 @@ const Answer: FC = ({ ) } + {item.siblingCount && item.siblingCount > 1 && item.siblingIndex !== undefined &&
+ + {item.siblingIndex + 1} / {item.siblingCount} + +
} diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index 742632a1ad..22020066b4 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -65,6 +65,7 @@ export type ChatProps = { hideProcessDetail?: boolean hideLogModal?: boolean themeBuilder?: ThemeBuilder + switchSibling?: (siblingMessageId: string) => void showFeatureBar?: boolean showFileUpload?: boolean onFeatureBarClick?: (state: boolean) => void @@ -100,6 +101,7 @@ const Chat: FC = ({ hideProcessDetail, hideLogModal, themeBuilder, + switchSibling, showFeatureBar, showFileUpload, onFeatureBarClick, @@ -232,6 +234,7 @@ const Chat: FC = ({ chatAnswerContainerInner={chatAnswerContainerInner} hideProcessDetail={hideProcessDetail} noChatInput={noChatInput} + switchSibling={switchSibling} /> ) } diff --git a/web/app/components/base/chat/chat/type.ts b/web/app/components/base/chat/chat/type.ts index 40cc32e859..7f22ba05b7 100644 --- a/web/app/components/base/chat/chat/type.ts +++ b/web/app/components/base/chat/chat/type.ts @@ -97,7 +97,11 @@ export type IChatItem = { // for agent log conversationId?: string input?: any - parentMessageId?: string + parentMessageId?: string | null + siblingCount?: number + siblingIndex?: number + prevSibling?: string + nextSibling?: string } export type Metadata = { diff --git a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx index ae4667306f..04f65b549c 100644 --- a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx +++ b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx @@ -42,6 +42,10 @@ const ChatWrapper = () => { return { ...config, + file_upload: { + ...(config as any).file_upload, + fileUploadConfig: (config as any).system_parameters, + }, supportFeedback: true, opening_statement: currentConversationId ? currentConversationItem?.introduction : (config as any).opening_statement, } as ChatConfig diff --git a/web/app/components/base/chat/embedded-chatbot/config-panel/form.tsx b/web/app/components/base/chat/embedded-chatbot/config-panel/form.tsx index 211907d48b..718b9a9d94 100644 --- a/web/app/components/base/chat/embedded-chatbot/config-panel/form.tsx +++ b/web/app/components/base/chat/embedded-chatbot/config-panel/form.tsx @@ -9,6 +9,7 @@ import { FileUploaderInAttachmentWrapper } from '@/app/components/base/file-uplo const Form = () => { const { t } = useTranslation() const { + appParams, inputsForms, newConversationInputs, newConversationInputsRef, @@ -73,6 +74,7 @@ const Form = () => { allowed_file_extensions: form.allowed_file_extensions, allowed_file_upload_methods: form.allowed_file_upload_methods, number_limits: 1, + fileUploadConfig: (appParams as any).system_parameters, }} /> ) @@ -87,6 +89,7 @@ const Form = () => { allowed_file_extensions: form.allowed_file_extensions, allowed_file_upload_methods: form.allowed_file_upload_methods, number_limits: form.max_length, + fileUploadConfig: (appParams as any).system_parameters, }} /> ) diff --git a/web/app/components/base/chat/types.ts b/web/app/components/base/chat/types.ts index 402392ac2a..8d9dacdcd7 100644 --- a/web/app/components/base/chat/types.ts +++ b/web/app/components/base/chat/types.ts @@ -65,6 +65,10 @@ export type ChatItem = IChatItem & { allFiles?: FileEntity[] } +export type ChatItemInTree = { + children?: ChatItemInTree[] +} & IChatItem + export type OnSend = (message: string, files?: FileEntity[], last_answer?: ChatItem | null) => void export type OnRegenerate = (chatItem: ChatItem) => void diff --git a/web/app/components/base/chat/utils.ts b/web/app/components/base/chat/utils.ts index 3840f6a2b8..16357361cf 100644 --- a/web/app/components/base/chat/utils.ts +++ b/web/app/components/base/chat/utils.ts @@ -1,6 +1,7 @@ import { addFileInfos, sortAgentSorts } from '../../tools/utils' import { UUID_NIL } from './constants' -import type { ChatItem } from './types' +import type { IChatItem } from './chat/type' +import type { ChatItem, ChatItemInTree } from './types' import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' async function decodeBase64AndDecompress(base64String: string) { @@ -23,7 +24,7 @@ function getProcessedInputsFromUrlParams(): Record { function getLastAnswer(chatList: ChatItem[]) { for (let i = chatList.length - 1; i >= 0; i--) { const item = chatList[i] - if (item.isAnswer && !item.isOpeningStatement) + if (item.isAnswer && !item.id.startsWith('answer-placeholder-') && !item.isOpeningStatement) return item } return null @@ -81,8 +82,131 @@ function getPrevChatList(fetchedMessages: any[]) { return ret.reverse() } +function buildChatItemTree(allMessages: IChatItem[]): ChatItemInTree[] { + const map: Record = {} + const rootNodes: ChatItemInTree[] = [] + const childrenCount: Record = {} + + let lastAppendedLegacyAnswer: ChatItemInTree | null = null + for (let i = 0; i < allMessages.length; i += 2) { + const question = allMessages[i]! + const answer = allMessages[i + 1]! + + const isLegacy = question.parentMessageId === UUID_NIL + const parentMessageId = isLegacy + ? (lastAppendedLegacyAnswer?.id || '') + : (question.parentMessageId || '') + + // Process question + childrenCount[parentMessageId] = (childrenCount[parentMessageId] || 0) + 1 + const questionNode: ChatItemInTree = { + ...question, + children: [], + } + map[question.id] = questionNode + + // Process answer + childrenCount[question.id] = 1 + const answerNode: ChatItemInTree = { + ...answer, + children: [], + siblingIndex: isLegacy ? 0 : childrenCount[parentMessageId] - 1, + } + map[answer.id] = answerNode + + // Connect question and answer + questionNode.children!.push(answerNode) + + // Append to parent or add to root + if (isLegacy) { + if (!lastAppendedLegacyAnswer) + rootNodes.push(questionNode) + else + lastAppendedLegacyAnswer.children!.push(questionNode) + + lastAppendedLegacyAnswer = answerNode + } + else { + if (!parentMessageId) + rootNodes.push(questionNode) + else + map[parentMessageId]?.children!.push(questionNode) + } + } + + return rootNodes +} + +function getThreadMessages(tree: ChatItemInTree[], targetMessageId?: string): ChatItemInTree[] { + let ret: ChatItemInTree[] = [] + let targetNode: ChatItemInTree | undefined + + // find path to the target message + const stack = tree.toReversed().map(rootNode => ({ + node: rootNode, + path: [rootNode], + })) + while (stack.length > 0) { + const { node, path } = stack.pop()! + if ( + node.id === targetMessageId + || (!targetMessageId && !node.children?.length && !stack.length) // if targetMessageId is not provided, we use the last message in the tree as the target + ) { + targetNode = node + ret = path.map((item, index) => { + if (!item.isAnswer) + return item + + const parentAnswer = path[index - 2] + const siblingCount = !parentAnswer ? tree.length : parentAnswer.children!.length + const prevSibling = !parentAnswer ? tree[item.siblingIndex! - 1]?.children?.[0]?.id : parentAnswer.children![item.siblingIndex! - 1]?.children?.[0].id + const nextSibling = !parentAnswer ? tree[item.siblingIndex! + 1]?.children?.[0]?.id : parentAnswer.children![item.siblingIndex! + 1]?.children?.[0].id + + return { ...item, siblingCount, prevSibling, nextSibling } + }) + break + } + if (node.children) { + for (let i = node.children.length - 1; i >= 0; i--) { + stack.push({ + node: node.children[i], + path: [...path, node.children[i]], + }) + } + } + } + + // append all descendant messages to the path + if (targetNode) { + const stack = [targetNode] + while (stack.length > 0) { + const node = stack.pop()! + if (node !== targetNode) + ret.push(node) + if (node.children?.length) { + const lastChild = node.children.at(-1)! + + if (!lastChild.isAnswer) { + stack.push(lastChild) + continue + } + + const parentAnswer = ret.at(-2) + const siblingCount = parentAnswer?.children?.length + const prevSibling = parentAnswer?.children?.at(-2)?.children?.[0]?.id + + stack.push({ ...lastChild, siblingCount, prevSibling }) + } + } + } + + return ret +} + export { getProcessedInputsFromUrlParams, - getLastAnswer, getPrevChatList, + getLastAnswer, + buildChatItemTree, + getThreadMessages, } diff --git a/web/app/components/base/features/types.ts b/web/app/components/base/features/types.ts index 3307f12cda..83f876383d 100644 --- a/web/app/components/base/features/types.ts +++ b/web/app/components/base/features/types.ts @@ -1,4 +1,5 @@ import type { Resolution, TransferMethod, TtsAutoPlay } from '@/types/app' +import type { FileUploadConfigResponse } from '@/models/common' export type EnabledOrDisabled = { enabled?: boolean @@ -38,6 +39,7 @@ export type FileUpload = { allowed_file_extensions?: string[] allowed_file_upload_methods?: TransferMethod[] number_limits?: number + fileUploadConfig?: FileUploadConfigResponse } & EnabledOrDisabled export type AnnotationReplyConfig = { diff --git a/web/app/components/base/file-uploader/constants.ts b/web/app/components/base/file-uploader/constants.ts index e6cc2995f9..629fe2566b 100644 --- a/web/app/components/base/file-uploader/constants.ts +++ b/web/app/components/base/file-uploader/constants.ts @@ -1,3 +1,7 @@ +// fallback for file size limit of dify_config +export const IMG_SIZE_LIMIT = 10 * 1024 * 1024 export const FILE_SIZE_LIMIT = 15 * 1024 * 1024 +export const AUDIO_SIZE_LIMIT = 50 * 1024 * 1024 +export const VIDEO_SIZE_LIMIT = 100 * 1024 * 1024 export const FILE_URL_REGEX = /^(https?|ftp):\/\// diff --git a/web/app/components/base/file-uploader/hooks.ts b/web/app/components/base/file-uploader/hooks.ts index 5e126a87b5..942e5d612a 100644 --- a/web/app/components/base/file-uploader/hooks.ts +++ b/web/app/components/base/file-uploader/hooks.ts @@ -14,19 +14,113 @@ import { getSupportFileType, isAllowedFileExtension, } from './utils' -import { FILE_SIZE_LIMIT } from './constants' +import { + AUDIO_SIZE_LIMIT, + FILE_SIZE_LIMIT, + IMG_SIZE_LIMIT, + VIDEO_SIZE_LIMIT, +} from '@/app/components/base/file-uploader/constants' import { useToastContext } from '@/app/components/base/toast' import { TransferMethod } from '@/types/app' import { SupportUploadFileTypes } from '@/app/components/workflow/types' import type { FileUpload } from '@/app/components/base/features/types' import { formatFileSize } from '@/utils/format' import { fetchRemoteFileInfo } from '@/service/common' +import type { FileUploadConfigResponse } from '@/models/common' + +export const useFileSizeLimit = (fileUploadConfig?: FileUploadConfigResponse) => { + const imgSizeLimit = Number(fileUploadConfig?.image_file_size_limit) * 1024 * 1024 || IMG_SIZE_LIMIT + const docSizeLimit = Number(fileUploadConfig?.file_size_limit) * 1024 * 1024 || FILE_SIZE_LIMIT + const audioSizeLimit = Number(fileUploadConfig?.audio_file_size_limit) * 1024 * 1024 || AUDIO_SIZE_LIMIT + const videoSizeLimit = Number(fileUploadConfig?.video_file_size_limit) * 1024 * 1024 || VIDEO_SIZE_LIMIT + + return { + imgSizeLimit, + docSizeLimit, + audioSizeLimit, + videoSizeLimit, + } +} export const useFile = (fileConfig: FileUpload) => { const { t } = useTranslation() const { notify } = useToastContext() const fileStore = useFileStore() const params = useParams() + const { imgSizeLimit, docSizeLimit, audioSizeLimit, videoSizeLimit } = useFileSizeLimit(fileConfig.fileUploadConfig) + + const checkSizeLimit = (fileType: string, fileSize: number) => { + switch (fileType) { + case SupportUploadFileTypes.image: { + if (fileSize > imgSizeLimit) { + notify({ + type: 'error', + message: t('common.fileUploader.uploadFromComputerLimit', { + type: SupportUploadFileTypes.image, + size: formatFileSize(imgSizeLimit), + }), + }) + return false + } + return true + } + case SupportUploadFileTypes.document: { + if (fileSize > docSizeLimit) { + notify({ + type: 'error', + message: t('common.fileUploader.uploadFromComputerLimit', { + type: SupportUploadFileTypes.document, + size: formatFileSize(docSizeLimit), + }), + }) + return false + } + return true + } + case SupportUploadFileTypes.audio: { + if (fileSize > audioSizeLimit) { + notify({ + type: 'error', + message: t('common.fileUploader.uploadFromComputerLimit', { + type: SupportUploadFileTypes.audio, + size: formatFileSize(audioSizeLimit), + }), + }) + return false + } + return true + } + case SupportUploadFileTypes.video: { + if (fileSize > videoSizeLimit) { + notify({ + type: 'error', + message: t('common.fileUploader.uploadFromComputerLimit', { + type: SupportUploadFileTypes.video, + size: formatFileSize(videoSizeLimit), + }), + }) + return false + } + return true + } + case SupportUploadFileTypes.custom: { + if (fileSize > docSizeLimit) { + notify({ + type: 'error', + message: t('common.fileUploader.uploadFromComputerLimit', { + type: SupportUploadFileTypes.document, + size: formatFileSize(docSizeLimit), + }), + }) + return false + } + return true + } + default: { + return true + } + } + } const handleAddFile = useCallback((newFile: FileEntity) => { const { @@ -117,12 +211,15 @@ export const useFile = (fileConfig: FileUpload) => { progress: 100, supportFileType: getSupportFileType(url, res.file_type, allowedFileTypes?.includes(SupportUploadFileTypes.custom)), } - handleUpdateFile(newFile) + if (!checkSizeLimit(newFile.supportFileType, newFile.size)) + handleRemoveFile(uploadingFile.id) + else + handleUpdateFile(newFile) }).catch(() => { notify({ type: 'error', message: t('common.fileUploader.pasteFileLinkInvalid') }) handleRemoveFile(uploadingFile.id) }) - }, [handleAddFile, handleUpdateFile, notify, t, handleRemoveFile, fileConfig?.allowed_file_types]) + }, [checkSizeLimit, handleAddFile, handleUpdateFile, notify, t, handleRemoveFile, fileConfig?.allowed_file_types]) const handleLoadFileFromLinkSuccess = useCallback(() => { }, []) @@ -140,13 +237,13 @@ export const useFile = (fileConfig: FileUpload) => { notify({ type: 'error', message: t('common.fileUploader.fileExtensionNotSupport') }) return } - if (file.size > FILE_SIZE_LIMIT) { - notify({ type: 'error', message: t('common.fileUploader.uploadFromComputerLimit', { size: formatFileSize(FILE_SIZE_LIMIT) }) }) + const allowedFileTypes = fileConfig.allowed_file_types + const fileType = getSupportFileType(file.name, file.type, allowedFileTypes?.includes(SupportUploadFileTypes.custom)) + if (!checkSizeLimit(fileType, file.size)) return - } + const reader = new FileReader() const isImage = file.type.startsWith('image') - const allowedFileTypes = fileConfig.allowed_file_types reader.addEventListener( 'load', @@ -187,7 +284,7 @@ export const useFile = (fileConfig: FileUpload) => { false, ) reader.readAsDataURL(file) - }, [notify, t, handleAddFile, handleUpdateFile, params.token, fileConfig?.allowed_file_types, fileConfig?.allowed_file_extensions]) + }, [checkSizeLimit, notify, t, handleAddFile, handleUpdateFile, params.token, fileConfig?.allowed_file_types, fileConfig?.allowed_file_extensions]) const handleClipboardPasteFile = useCallback((e: ClipboardEvent) => { const file = e.clipboardData?.files[0] diff --git a/web/app/components/base/markdown-blocks/button.tsx b/web/app/components/base/markdown-blocks/button.tsx new file mode 100644 index 0000000000..56647b3bbe --- /dev/null +++ b/web/app/components/base/markdown-blocks/button.tsx @@ -0,0 +1,22 @@ +import { useChatContext } from '@/app/components/base/chat/chat/context' +import Button from '@/app/components/base/button' +import cn from '@/utils/classnames' + +const MarkdownButton = ({ node }: any) => { + const { onSend } = useChatContext() + const variant = node.properties.dataVariant + const message = node.properties.dataMessage + const size = node.properties.dataSize + + return +} +MarkdownButton.displayName = 'MarkdownButton' + +export default MarkdownButton diff --git a/web/app/components/base/markdown-blocks/form.tsx b/web/app/components/base/markdown-blocks/form.tsx new file mode 100644 index 0000000000..f87f2dcd91 --- /dev/null +++ b/web/app/components/base/markdown-blocks/form.tsx @@ -0,0 +1,137 @@ +import Button from '@/app/components/base/button' +import Input from '@/app/components/base/input' +import Textarea from '@/app/components/base/textarea' +import { useChatContext } from '@/app/components/base/chat/chat/context' + +enum DATA_FORMAT { + TEXT = 'text', + JSON = 'json', +} +enum SUPPORTED_TAGS { + LABEL = 'label', + INPUT = 'input', + TEXTAREA = 'textarea', + BUTTON = 'button', +} +enum SUPPORTED_TYPES { + TEXT = 'text', + PASSWORD = 'password', + EMAIL = 'email', + NUMBER = 'number', +} +const MarkdownForm = ({ node }: any) => { + // const supportedTypes = ['text', 'password', 'email', 'number'] + //
+ // + // + // + // + // + // + // + //
+ const { onSend } = useChatContext() + + const getFormValues = (children: any) => { + const formValues: { [key: string]: any } = {} + children.forEach((child: any) => { + if (child.tagName === SUPPORTED_TAGS.INPUT) + formValues[child.properties.name] = child.properties.value + if (child.tagName === SUPPORTED_TAGS.TEXTAREA) + formValues[child.properties.name] = child.properties.value + }) + return formValues + } + const onSubmit = (e: any) => { + e.preventDefault() + const format = node.properties.dataFormat || DATA_FORMAT.TEXT + const result = getFormValues(node.children) + if (format === DATA_FORMAT.JSON) { + onSend?.(JSON.stringify(result)) + } + else { + const textResult = Object.entries(result) + .map(([key, value]) => `${key}: ${value}`) + .join('\n') + onSend?.(textResult) + } + } + return ( +
{ + e.preventDefault() + e.stopPropagation() + }} + > + {node.children.filter((i: any) => i.type === 'element').map((child: any, index: number) => { + if (child.tagName === SUPPORTED_TAGS.LABEL) { + return ( + + ) + } + if (child.tagName === SUPPORTED_TAGS.INPUT) { + if (Object.values(SUPPORTED_TYPES).includes(child.properties.type)) { + return ( + { + e.preventDefault() + child.properties.value = e.target.value + }} + /> + ) + } + else { + return

Unsupported input type: {child.properties.type}

+ } + } + if (child.tagName === SUPPORTED_TAGS.TEXTAREA) { + return ( +