Mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 11:42:29 +08:00)
Merge remote-tracking branch 'origin'

commit e4d5d9113f
4 .github/workflows/style.yml vendored

@@ -45,6 +45,10 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example

+      - name: Ruff formatter check
+        if: steps.changed-files.outputs.any_changed == 'true'
+        run: poetry run -C api ruff format --check ./api
+
       - name: Lint hints
         if: failure()
         run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
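The new CI step only checks formatting; it rewrites nothing. The same check can be reproduced locally before pushing. A minimal sketch, assuming Poetry and the api virtualenv are set up and the repository root is the working directory:

```python
import subprocess

# Run the same Ruff formatter check the workflow adds; with --check, Ruff
# only reports files that would be reformatted instead of rewriting them.
result = subprocess.run(
    ["poetry", "run", "-C", "api", "ruff", "format", "--check", "./api"],
    capture_output=True,
    text=True,
)
print(result.stdout)
# A non-zero return code means some files need fixing, e.g. via dev/reformat.
print("clean" if result.returncode == 0 else "needs formatting")
```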
52 .github/workflows/translate-i18n-base-on-english.yml vendored Normal file

@@ -0,0 +1,52 @@
name: Check i18n Files and Create PR

on:
  pull_request:
    types: [closed]
    branches: [main]

jobs:
  check-and-update:
    if: github.event.pull_request.merged == true
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: web
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Check for file changes in i18n/en-US
        id: check_files
        run: |
          changed_files=$(git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} -- 'i18n/en-US/*.ts')
          echo "Changed files: $changed_files"
          if [ -n "$changed_files" ]; then
            echo "FILES_CHANGED=true" >> $GITHUB_ENV
          else
            echo "FILES_CHANGED=false" >> $GITHUB_ENV
          fi

      - name: Set up Node.js
        if: env.FILES_CHANGED == 'true'
        uses: actions/setup-node@v2
        with:
          node-version: 'lts/*'

      - name: Install dependencies
        if: env.FILES_CHANGED == 'true'
        run: yarn install --frozen-lockfile

      - name: Run npm script
        if: env.FILES_CHANGED == 'true'
        run: npm run auto-gen-i18n

      - name: Create Pull Request
        if: env.FILES_CHANGED == 'true'
        uses: peter-evans/create-pull-request@v6
        with:
          commit-message: Update i18n files based on en-US changes
          title: 'chore: translate i18n files'
          body: This PR was automatically created to update i18n files based on changes in en-US locale.
          branch: chore/automated-i18n-updates
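The detection step keys off `git diff --name-only` between the merged PR's base and head SHAs, restricted to the `i18n/en-US/*.ts` pathspec. A rough local equivalent, sketched in Python; the SHAs here are placeholder stand-ins for what the workflow receives from the pull_request event:

```python
import subprocess

# Hypothetical stand-ins for github.event.pull_request.{base,head}.sha
base_sha = "origin/main"
head_sha = "HEAD"

changed = subprocess.run(
    ["git", "diff", "--name-only", base_sha, head_sha, "--", "i18n/en-US/*.ts"],
    capture_output=True,
    text=True,
    cwd="web",  # the workflow sets working-directory: web
).stdout.split()

# Mirrors the FILES_CHANGED flag the workflow exports to $GITHUB_ENV.
files_changed = bool(changed)
print(f"FILES_CHANGED={str(files_changed).lower()}", changed)
```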
1 .gitignore vendored

@@ -178,3 +178,4 @@ pyrightconfig.json
 api/.vscode
+.idea/
 .vscode
CONTRIBUTING.md

@@ -8,7 +8,7 @@ In terms of licensing, please take a minute to read our short [License and Contr

 ## Before you jump in

-[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
+[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:

 ### Feature requests:
CONTRIBUTING_CN.md

@@ -8,7 +8,7 @@

 ## 在开始之前

-[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
+[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:open)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:

 ### 功能请求:
CONTRIBUTING_JA.md

@@ -10,7 +10,7 @@ Dify にコントリビュートしたいとお考えなのですね。それは

 ## 飛び込む前に

-[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
+[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:open) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。

 ### 機能リクエスト
156 CONTRIBUTING_VI.md Normal file

@@ -0,0 +1,156 @@
Thật tuyệt vời khi bạn muốn đóng góp cho Dify! Chúng tôi rất mong chờ được thấy những gì bạn sẽ làm. Là một startup với nguồn nhân lực và tài chính hạn chế, chúng tôi có tham vọng lớn là thiết kế quy trình trực quan nhất để xây dựng và quản lý các ứng dụng LLM. Mọi sự giúp đỡ từ cộng đồng đều rất quý giá đối với chúng tôi.

Chúng tôi cần linh hoạt và làm việc nhanh chóng, nhưng đồng thời cũng muốn đảm bảo các cộng tác viên như bạn có trải nghiệm đóng góp thuận lợi nhất có thể. Chúng tôi đã tạo ra hướng dẫn đóng góp này nhằm giúp bạn làm quen với codebase và cách chúng tôi làm việc với các cộng tác viên, để bạn có thể nhanh chóng bắt tay vào phần thú vị.

Hướng dẫn này, cũng như bản thân Dify, đang trong quá trình cải tiến liên tục. Chúng tôi rất cảm kích sự thông cảm của bạn nếu đôi khi nó không theo kịp dự án thực tế, và chúng tôi luôn hoan nghênh mọi phản hồi để cải thiện.

Về vấn đề cấp phép, xin vui lòng dành chút thời gian đọc qua [Thỏa thuận Cấp phép và Đóng góp](./LICENSE) ngắn gọn của chúng tôi. Cộng đồng cũng tuân thủ [quy tắc ứng xử](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md).

## Trước khi bắt đầu

[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:open) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại:

### Yêu cầu tính năng:

* Nếu bạn đang tạo một yêu cầu tính năng mới, chúng tôi muốn bạn giải thích tính năng đề xuất sẽ đạt được điều gì và cung cấp càng nhiều thông tin chi tiết càng tốt. [@perzeusss](https://github.com/perzeuss) đã tạo một [Trợ lý Yêu cầu Tính năng](https://udify.app/chat/MK2kVSnw1gakVwMX) rất hữu ích để giúp bạn soạn thảo nhu cầu của mình. Hãy thử dùng nó nhé.

* Nếu bạn muốn chọn một vấn đề từ danh sách hiện có, chỉ cần để lại bình luận dưới vấn đề đó nói rằng bạn sẽ làm.

Một thành viên trong nhóm làm việc trong lĩnh vực liên quan sẽ được thông báo. Nếu mọi thứ ổn, họ sẽ cho phép bạn bắt đầu code. Chúng tôi yêu cầu bạn chờ đợi cho đến lúc đó trước khi bắt tay vào làm tính năng, để không lãng phí công sức của bạn nếu chúng tôi đề xuất thay đổi.

Tùy thuộc vào lĩnh vực mà tính năng đề xuất thuộc về, bạn có thể nói chuyện với các thành viên khác nhau trong nhóm. Dưới đây là danh sách các lĩnh vực mà các thành viên trong nhóm chúng tôi đang làm việc hiện tại:

| Thành viên | Phạm vi |
| ------------------------------------------------------------ | ---------------------------------------------------- |
| [@yeuoly](https://github.com/Yeuoly) | Thiết kế kiến trúc Agents |
| [@jyong](https://github.com/JohnJyong) | Thiết kế quy trình RAG |
| [@GarfieldDai](https://github.com/GarfieldDai) | Xây dựng quy trình làm việc |
| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | Làm cho giao diện người dùng dễ sử dụng |
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | Trải nghiệm nhà phát triển, đầu mối liên hệ cho mọi vấn đề |
| [@takatost](https://github.com/takatost) | Định hướng và kiến trúc tổng thể sản phẩm |

Cách chúng tôi ưu tiên:

| Loại tính năng | Mức độ ưu tiên |
| ------------------------------------------------------------ | -------------- |
| Tính năng ưu tiên cao được gắn nhãn bởi thành viên trong nhóm | Ưu tiên cao |
| Yêu cầu tính năng phổ biến từ [bảng phản hồi cộng đồng](https://github.com/langgenius/dify/discussions/categories/feedbacks) của chúng tôi | Ưu tiên trung bình |
| Tính năng không quan trọng và cải tiến nhỏ | Ưu tiên thấp |
| Có giá trị nhưng không cấp bách | Tính năng tương lai |

### Những vấn đề khác (ví dụ: báo cáo lỗi, tối ưu hiệu suất, sửa lỗi chính tả):

* Bắt đầu code ngay lập tức.

Cách chúng tôi ưu tiên:

| Loại vấn đề | Mức độ ưu tiên |
| ------------------------------------------------------------ | -------------- |
| Lỗi trong các chức năng chính (không thể đăng nhập, ứng dụng không hoạt động, lỗ hổng bảo mật) | Nghiêm trọng |
| Lỗi không quan trọng, cải thiện hiệu suất | Ưu tiên trung bình |
| Sửa lỗi nhỏ (lỗi chính tả, giao diện người dùng gây nhầm lẫn nhưng vẫn hoạt động) | Ưu tiên thấp |

## Cài đặt

Dưới đây là các bước để thiết lập Dify cho việc phát triển:

### 1. Fork repository này

### 2. Clone repository

Clone repository đã fork từ terminal của bạn:

```
git clone git@github.com:<tên_người_dùng_github>/dify.git
```

### 3. Kiểm tra các phụ thuộc

Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đã được cài đặt trên hệ thống của bạn:

- [Docker](https://www.docker.com/)
- [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) phiên bản 3.10.x

### 4. Cài đặt

Dify bao gồm một backend và một frontend. Đi đến thư mục backend bằng lệnh `cd api/`, sau đó làm theo hướng dẫn trong [README của Backend](api/README.md) để cài đặt. Trong một terminal khác, đi đến thư mục frontend bằng lệnh `cd web/`, sau đó làm theo hướng dẫn trong [README của Frontend](web/README.md) để cài đặt.

Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/self-host-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục.

### 5. Truy cập Dify trong trình duyệt của bạn

Để xác nhận cài đặt của bạn, hãy truy cập [http://localhost:3000](http://localhost:3000) (địa chỉ mặc định, hoặc URL và cổng bạn đã cấu hình) trong trình duyệt. Bạn sẽ thấy Dify đang chạy.

## Phát triển

Nếu bạn đang thêm một nhà cung cấp mô hình, [hướng dẫn này](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md) dành cho bạn.

Nếu bạn đang thêm một nhà cung cấp công cụ cho Agent hoặc Workflow, [hướng dẫn này](./api/core/tools/README.md) dành cho bạn.

Để giúp bạn nhanh chóng định hướng phần đóng góp của mình, dưới đây là một bản phác thảo ngắn gọn về cấu trúc backend & frontend của Dify:

### Backend

Backend của Dify được viết bằng Python sử dụng [Flask](https://flask.palletsprojects.com/en/3.0.x/). Nó sử dụng [SQLAlchemy](https://www.sqlalchemy.org/) cho ORM và [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) cho hàng đợi tác vụ. Logic xác thực được thực hiện thông qua Flask-login.

```
[api/]
├── constants         // Các cài đặt hằng số được sử dụng trong toàn bộ codebase.
├── controllers       // Định nghĩa các route API và logic xử lý yêu cầu.
├── core              // Điều phối ứng dụng cốt lõi, tích hợp mô hình và công cụ.
├── docker            // Cấu hình liên quan đến Docker & containerization.
├── events            // Xử lý và xử lý sự kiện
├── extensions        // Mở rộng với các framework/nền tảng bên thứ 3.
├── fields            // Định nghĩa trường cho serialization/marshalling.
├── libs              // Thư viện và tiện ích có thể tái sử dụng.
├── migrations        // Script cho việc di chuyển cơ sở dữ liệu.
├── models            // Mô hình cơ sở dữ liệu & định nghĩa schema.
├── services          // Xác định logic nghiệp vụ.
├── storage           // Lưu trữ khóa riêng tư.
├── tasks             // Xử lý các tác vụ bất đồng bộ và công việc nền.
└── tests
```

### Frontend

Website được khởi tạo trên boilerplate [Next.js](https://nextjs.org/) bằng Typescript và sử dụng [Tailwind CSS](https://tailwindcss.com/) cho styling. [React-i18next](https://react.i18next.com/) được sử dụng cho việc quốc tế hóa.

```
[web/]
├── app                   // layouts, pages và components
│   ├── (commonLayout)    // layout chung được sử dụng trong toàn bộ ứng dụng
│   ├── (shareLayout)     // layouts được chia sẻ cụ thể cho các phiên dựa trên token
│   ├── activate          // trang kích hoạt
│   ├── components        // được chia sẻ bởi các trang và layouts
│   ├── install           // trang cài đặt
│   ├── signin            // trang đăng nhập
│   └── styles            // styles được chia sẻ toàn cục
├── assets                // Tài nguyên tĩnh
├── bin                   // scripts chạy ở bước build
├── config                // cài đặt và tùy chọn có thể điều chỉnh
├── context               // contexts được chia sẻ bởi các phần khác nhau của ứng dụng
├── dictionaries          // File dịch cho từng ngôn ngữ
├── docker                // cấu hình container
├── hooks                 // Hooks có thể tái sử dụng
├── i18n                  // Cấu hình quốc tế hóa
├── models                // mô tả các mô hình dữ liệu & hình dạng của phản hồi API
├── public                // tài nguyên meta như favicon
├── service               // xác định hình dạng của các hành động API
├── test
├── types                 // mô tả các tham số hàm và giá trị trả về
└── utils                 // Các hàm tiện ích được chia sẻ
```

## Gửi PR của bạn

Cuối cùng, đã đến lúc mở một pull request (PR) đến repository của chúng tôi. Đối với các tính năng lớn, chúng tôi sẽ merge chúng vào nhánh `deploy/dev` để kiểm tra trước khi đưa vào nhánh `main`. Nếu bạn gặp vấn đề như xung đột merge hoặc không biết cách mở pull request, hãy xem [hướng dẫn về pull request của GitHub](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests).

Và thế là xong! Khi PR của bạn được merge, bạn sẽ được giới thiệu là một người đóng góp trong [README](https://github.com/langgenius/dify/blob/main/README.md) của chúng tôi.

## Nhận trợ giúp

Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng.
api/.env.example

@@ -60,7 +60,8 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key
 ALIYUN_OSS_ENDPOINT=your-endpoint
 ALIYUN_OSS_AUTH_VERSION=v1
 ALIYUN_OSS_REGION=your-region
+# Don't start with '/'. OSS doesn't support leading slash in object names.
 ALIYUN_OSS_PATH=your-path
 # Google Storage configuration
 GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
 GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string

@@ -247,8 +248,8 @@ API_TOOL_DEFAULT_READ_TIMEOUT=60
 HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300
 HTTP_REQUEST_MAX_READ_TIMEOUT=600
 HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
-HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 # 10MB
-HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 # 1MB
+HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
+HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576

 # Log file path
 LOG_FILE=

@@ -267,4 +268,13 @@ APP_MAX_ACTIVE_REQUESTS=0

 # Celery beat configuration
 CELERY_BEAT_SCHEDULER_TIME=1
+
+# Position configuration
+POSITION_TOOL_PINS=
+POSITION_TOOL_INCLUDES=
+POSITION_TOOL_EXCLUDES=
+
+POSITION_PROVIDER_PINS=
+POSITION_PROVIDER_INCLUDES=
+POSITION_PROVIDER_EXCLUDES=
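The inline `# 10MB` / `# 1MB` notes were dropped from the size values, plausibly because a trailing comment on the same line can end up inside the value when the file is loaded naively, which then breaks numeric parsing. A minimal sketch of the failure mode, assuming the value is read from `os.environ` and cast with `int()`:

```python
import os

# Hypothetical: simulate an env value that kept its inline comment.
os.environ["HTTP_REQUEST_NODE_MAX_BINARY_SIZE"] = "10485760 # 10MB"

raw = os.environ["HTTP_REQUEST_NODE_MAX_BINARY_SIZE"]
try:
    size = int(raw)
except ValueError:
    # int() rejects the trailing comment; keeping the note on its own
    # line (as the diff now does) avoids the problem entirely.
    size = int(raw.split("#", 1)[0].strip())

print(size)  # 10485760
```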
(binary image changed: 1.7 KiB before, 1.7 KiB after)
api/.idea/vcs.xml

@@ -12,5 +12,6 @@
   </component>
   <component name="VcsDirectoryMappings">
-    <mapping directory="" vcs="Git" />
+    <mapping directory="$PROJECT_DIR$/.." vcs="Git" />
   </component>
 </project>
api/.vscode/launch.json

@@ -5,8 +5,8 @@
         "name": "Python: Flask",
         "type": "debugpy",
         "request": "launch",
-        "python": "${workspaceFolder}/api/.venv/bin/python",
-        "cwd": "${workspaceFolder}/api",
+        "python": "${workspaceFolder}/.venv/bin/python",
+        "cwd": "${workspaceFolder}",
         "envFile": ".env",
         "module": "flask",
         "justMyCode": true,

@@ -18,15 +18,15 @@
         "args": [
             "run",
             "--host=0.0.0.0",
-            "--port=5001",
+            "--port=5001"
         ]
     },
     {
         "name": "Python: Celery",
         "type": "debugpy",
         "request": "launch",
-        "python": "${workspaceFolder}/api/.venv/bin/python",
-        "cwd": "${workspaceFolder}/api",
+        "python": "${workspaceFolder}/.venv/bin/python",
+        "cwd": "${workspaceFolder}",
         "module": "celery",
         "justMyCode": true,
         "envFile": ".env",
api/Dockerfile

@@ -5,6 +5,10 @@ WORKDIR /app/api

 # Install Poetry
 ENV POETRY_VERSION=1.8.3
+
+# if you located in China, you can use aliyun mirror to speed up
+# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
+
 RUN pip install --no-cache-dir poetry==${POETRY_VERSION}

 # Configure Poetry

@@ -16,6 +20,9 @@ ENV POETRY_REQUESTS_TIMEOUT=15

 FROM base AS packages

+# if you located in China, you can use aliyun mirror to speed up
+# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
+
 RUN apt-get update \
     && apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev

@@ -43,10 +50,12 @@ WORKDIR /app/api

 RUN apt-get update \
     && apt-get install -y --no-install-recommends curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
+    # if you located in China, you can use aliyun mirror to speed up
+    # && echo "deb http://mirrors.aliyun.com/debian testing main" > /etc/apt/sources.list \
     && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
     && apt-get update \
     # For Security
-    && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-2 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
+    && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-2 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
     && apt-get autoremove -y \
     && rm -rf /var/lib/apt/lists/*

@@ -56,7 +65,7 @@ COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
 ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"

 # Download nltk data
-RUN python -c "import nltk; nltk.download('punkt')"
+RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"

 # Copy source code
 COPY . /app/api/
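The image now fetches the `averaged_perceptron_tagger` model alongside `punkt` at build time. A quick sketch for verifying both resources are present inside a built image (assumes only that the `nltk` package is importable):

```python
import nltk

# Both resources are downloaded during the Docker build; nltk.data.find()
# raises LookupError if a resource is missing from the NLTK search path.
for resource in ("tokenizers/punkt", "taggers/averaged_perceptron_tagger"):
    try:
        nltk.data.find(resource)
        print(f"{resource}: ok")
    except LookupError:
        print(f"{resource}: missing")
```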
161 api/app.py

@@ -1,6 +1,6 @@
 import os

-if os.environ.get("DEBUG", "false").lower() != 'true':
+if os.environ.get("DEBUG", "false").lower() != "true":
     from gevent import monkey

     monkey.patch_all()

@@ -57,7 +57,7 @@ warnings.simplefilter("ignore", ResourceWarning)
 if os.name == "nt":
     os.system('tzutil /s "UTC"')
 else:
-    os.environ['TZ'] = 'UTC'
+    os.environ["TZ"] = "UTC"
     time.tzset()

@@ -70,13 +70,14 @@ class DifyApp(Flask):
 # -------------

-config_type = os.getenv('EDITION', default='SELF_HOSTED')  # ce edition first
+config_type = os.getenv("EDITION", default="SELF_HOSTED")  # ce edition first

 # ----------------------------
 # Application Factory Function
 # ----------------------------

+
 def create_flask_app_with_configs() -> Flask:
     """
     create a raw flask app

@@ -92,7 +93,7 @@ def create_flask_app_with_configs() -> Flask:
         elif isinstance(value, int | float | bool):
             os.environ[key] = str(value)
         elif value is None:
-            os.environ[key] = ''
+            os.environ[key] = ""

     return dify_app

@@ -100,10 +101,10 @@ def create_flask_app_with_configs() -> Flask:
 def create_app() -> Flask:
     app = create_flask_app_with_configs()

-    app.secret_key = app.config['SECRET_KEY']
+    app.secret_key = app.config["SECRET_KEY"]

     log_handlers = None
-    log_file = app.config.get('LOG_FILE')
+    log_file = app.config.get("LOG_FILE")
     if log_file:
         log_dir = os.path.dirname(log_file)
         os.makedirs(log_dir, exist_ok=True)

@@ -111,23 +112,24 @@ def create_app() -> Flask:
             RotatingFileHandler(
                 filename=log_file,
                 maxBytes=1024 * 1024 * 1024,
-                backupCount=5
+                backupCount=5,
             ),
-            logging.StreamHandler(sys.stdout)
+            logging.StreamHandler(sys.stdout),
         ]

     logging.basicConfig(
-        level=app.config.get('LOG_LEVEL'),
-        format=app.config.get('LOG_FORMAT'),
-        datefmt=app.config.get('LOG_DATEFORMAT'),
+        level=app.config.get("LOG_LEVEL"),
+        format=app.config.get("LOG_FORMAT"),
+        datefmt=app.config.get("LOG_DATEFORMAT"),
         handlers=log_handlers,
-        force=True
+        force=True,
     )
-    log_tz = app.config.get('LOG_TZ')
+    log_tz = app.config.get("LOG_TZ")
     if log_tz:
         from datetime import datetime

         import pytz

         timezone = pytz.timezone(log_tz)

         def time_converter(seconds):

@@ -162,24 +164,24 @@ def initialize_extensions(app):
 @login_manager.request_loader
 def load_user_from_request(request_from_flask_login):
     """Load user based on the request."""
-    if request.blueprint not in ['console', 'inner_api']:
+    if request.blueprint not in ["console", "inner_api"]:
         return None
     # Check if the user_id contains a dot, indicating the old format
-    auth_header = request.headers.get('Authorization', '')
+    auth_header = request.headers.get("Authorization", "")
     if not auth_header:
-        auth_token = request.args.get('_token')
+        auth_token = request.args.get("_token")
         if not auth_token:
-            raise Unauthorized('Invalid Authorization token.')
+            raise Unauthorized("Invalid Authorization token.")
     else:
-        if ' ' not in auth_header:
-            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+        if " " not in auth_header:
+            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
         auth_scheme, auth_token = auth_header.split(None, 1)
         auth_scheme = auth_scheme.lower()
-        if auth_scheme != 'bearer':
-            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+        if auth_scheme != "bearer":
+            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")

     decoded = PassportService().verify(auth_token)
-    user_id = decoded.get('user_id')
+    user_id = decoded.get("user_id")

     account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
     if account:

@@ -190,10 +192,11 @@ def load_user_from_request(request_from_flask_login):
 @login_manager.unauthorized_handler
 def unauthorized_handler():
     """Handle unauthorized requests."""
-    return Response(json.dumps({
-        'code': 'unauthorized',
-        'message': "Unauthorized."
-    }), status=401, content_type="application/json")
+    return Response(
+        json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
+        status=401,
+        content_type="application/json",
+    )


 # register blueprint routers

@@ -204,38 +207,36 @@ def register_blueprints(app):
     from controllers.service_api import bp as service_api_bp
     from controllers.web import bp as web_bp

-    CORS(service_api_bp,
-         allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
-         methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
-         )
+    CORS(
+        service_api_bp,
+        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
+        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+    )
     app.register_blueprint(service_api_bp)

-    CORS(web_bp,
-         resources={
-             r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
-         supports_credentials=True,
-         allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
-         methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
-         expose_headers=['X-Version', 'X-Env']
-         )
+    CORS(
+        web_bp,
+        resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
+        supports_credentials=True,
+        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
+        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+        expose_headers=["X-Version", "X-Env"],
+    )

     app.register_blueprint(web_bp)

-    CORS(console_app_bp,
-         resources={
-             r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}},
-         supports_credentials=True,
-         allow_headers=['Content-Type', 'Authorization'],
-         methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
-         expose_headers=['X-Version', 'X-Env']
-         )
+    CORS(
+        console_app_bp,
+        resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
+        supports_credentials=True,
+        allow_headers=["Content-Type", "Authorization"],
+        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+        expose_headers=["X-Version", "X-Env"],
+    )

     app.register_blueprint(console_app_bp)

-    CORS(files_bp,
-         allow_headers=['Content-Type'],
-         methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
-         )
+    CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
     app.register_blueprint(files_bp)

     app.register_blueprint(inner_api_bp)

@@ -245,29 +246,29 @@ def register_blueprints(app):
 app = create_app()
 celery = app.extensions["celery"]

-if app.config.get('TESTING'):
+if app.config.get("TESTING"):
     print("App is running in TESTING mode")


 @app.after_request
 def after_request(response):
     """Add Version headers to the response."""
-    response.set_cookie('remember_token', '', expires=0)
-    response.headers.add('X-Version', app.config['CURRENT_VERSION'])
-    response.headers.add('X-Env', app.config['DEPLOY_ENV'])
+    response.set_cookie("remember_token", "", expires=0)
+    response.headers.add("X-Version", app.config["CURRENT_VERSION"])
+    response.headers.add("X-Env", app.config["DEPLOY_ENV"])
     return response


-@app.route('/health')
+@app.route("/health")
 def health():
-    return Response(json.dumps({
-        'pid': os.getpid(),
-        'status': 'ok',
-        'version': app.config['CURRENT_VERSION']
-    }), status=200, content_type="application/json")
+    return Response(
+        json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
+        status=200,
+        content_type="application/json",
+    )


-@app.route('/threads')
+@app.route("/threads")
 def threads():
     num_threads = threading.active_count()
     threads = threading.enumerate()

@@ -278,32 +279,34 @@ def threads():
         thread_id = thread.ident
         is_alive = thread.is_alive()

-        thread_list.append({
-            'name': thread_name,
-            'id': thread_id,
-            'is_alive': is_alive
-        })
+        thread_list.append(
+            {
+                "name": thread_name,
+                "id": thread_id,
+                "is_alive": is_alive,
+            }
+        )

     return {
-        'pid': os.getpid(),
-        'thread_num': num_threads,
-        'threads': thread_list
+        "pid": os.getpid(),
+        "thread_num": num_threads,
+        "threads": thread_list,
     }


-@app.route('/db-pool-stat')
+@app.route("/db-pool-stat")
 def pool_stat():
     engine = db.engine
     return {
-        'pid': os.getpid(),
-        'pool_size': engine.pool.size(),
-        'checked_in_connections': engine.pool.checkedin(),
-        'checked_out_connections': engine.pool.checkedout(),
-        'overflow_connections': engine.pool.overflow(),
-        'connection_timeout': engine.pool.timeout(),
-        'recycle_time': db.engine.pool._recycle
+        "pid": os.getpid(),
+        "pool_size": engine.pool.size(),
+        "checked_in_connections": engine.pool.checkedin(),
+        "checked_out_connections": engine.pool.checkedout(),
+        "overflow_connections": engine.pool.overflow(),
+        "connection_timeout": engine.pool.timeout(),
+        "recycle_time": db.engine.pool._recycle,
    }


-if __name__ == '__main__':
-    app.run(host='0.0.0.0', port=5001)
+if __name__ == "__main__":
+    app.run(host="0.0.0.0", port=5001)
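Nearly all of this file is quote and layout normalization from the Ruff formatter; the debug endpoints behave the same as before. A small usage sketch against a locally running API, assuming the server is up on the default port from `app.run(host="0.0.0.0", port=5001)` and that `requests` is installed:

```python
import requests

# /health returns pid, status and version as JSON.
health = requests.get("http://localhost:5001/health", timeout=5).json()
print(health["status"], health["version"])

# /threads and /db-pool-stat expose runtime diagnostics in the same style.
threads = requests.get("http://localhost:5001/threads", timeout=5).json()
print(threads["thread_num"])
```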
424 api/commands.py

@@ -27,32 +27,29 @@ from models.provider import Provider, ProviderModel
 from services.account_service import RegisterService, TenantService


-@click.command('reset-password', help='Reset the account password.')
-@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset')
-@click.option('--new-password', prompt=True, help='the new password.')
-@click.option('--password-confirm', prompt=True, help='the new password confirm.')
+@click.command("reset-password", help="Reset the account password.")
+@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset")
+@click.option("--new-password", prompt=True, help="the new password.")
+@click.option("--password-confirm", prompt=True, help="the new password confirm.")
 def reset_password(email, new_password, password_confirm):
     """
     Reset password of owner account
     Only available in SELF_HOSTED mode
     """
     if str(new_password).strip() != str(password_confirm).strip():
-        click.echo(click.style('sorry. The two passwords do not match.', fg='red'))
+        click.echo(click.style("sorry. The two passwords do not match.", fg="red"))
         return

-    account = db.session.query(Account). \
-        filter(Account.email == email). \
-        one_or_none()
+    account = db.session.query(Account).filter(Account.email == email).one_or_none()

     if not account:
-        click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
+        click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
         return

     try:
         valid_password(new_password)
     except:
-        click.echo(
-            click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red'))
+        click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red"))
         return

     # generate password salt

@@ -65,80 +62,87 @@ def reset_password(email, new_password, password_confirm):
     account.password = base64_password_hashed
     account.password_salt = base64_salt
     db.session.commit()
-    click.echo(click.style('Congratulations! Password has been reset.', fg='green'))
+    click.echo(click.style("Congratulations! Password has been reset.", fg="green"))


-@click.command('reset-email', help='Reset the account email.')
-@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset')
-@click.option('--new-email', prompt=True, help='the new email.')
-@click.option('--email-confirm', prompt=True, help='the new email confirm.')
+@click.command("reset-email", help="Reset the account email.")
+@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset")
+@click.option("--new-email", prompt=True, help="the new email.")
+@click.option("--email-confirm", prompt=True, help="the new email confirm.")
 def reset_email(email, new_email, email_confirm):
     """
     Replace account email
     :return:
     """
     if str(new_email).strip() != str(email_confirm).strip():
-        click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red'))
+        click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red"))
         return

-    account = db.session.query(Account). \
-        filter(Account.email == email). \
-        one_or_none()
+    account = db.session.query(Account).filter(Account.email == email).one_or_none()

     if not account:
-        click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
+        click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
         return

     try:
         email_validate(new_email)
     except:
-        click.echo(
-            click.style('sorry. {} is not a valid email. '.format(email), fg='red'))
+        click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red"))
         return

     account.email = new_email
     db.session.commit()
-    click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
+    click.echo(click.style("Congratulations!, email has been reset.", fg="green"))


-@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. '
-                                              'After the reset, all LLM credentials will become invalid, '
-                                              'requiring re-entry.'
-                                              'Only support SELF_HOSTED mode.')
-@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?'
-                                              ' this operation cannot be rolled back!', fg='red'))
+@click.command(
+    "reset-encrypt-key-pair",
+    help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
+    "After the reset, all LLM credentials will become invalid, "
+    "requiring re-entry."
+    "Only support SELF_HOSTED mode.",
+)
+@click.confirmation_option(
+    prompt=click.style(
+        "Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
+    )
+)
 def reset_encrypt_key_pair():
     """
     Reset the encrypted key pair of workspace for encrypt LLM credentials.
     After the reset, all LLM credentials will become invalid, requiring re-entry.
     Only support SELF_HOSTED mode.
     """
-    if dify_config.EDITION != 'SELF_HOSTED':
-        click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
+    if dify_config.EDITION != "SELF_HOSTED":
+        click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red"))
         return

     tenants = db.session.query(Tenant).all()
     for tenant in tenants:
         if not tenant:
-            click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
+            click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red"))
             return

         tenant.encrypt_public_key = generate_key_pair(tenant.id)

-        db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete()
+        db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
         db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
         db.session.commit()

-        click.echo(click.style('Congratulations! '
-                               'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
+        click.echo(
+            click.style(
+                "Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
+                fg="green",
+            )
+        )


-@click.command('vdb-migrate', help='migrate vector db.')
-@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
+@click.command("vdb-migrate", help="migrate vector db.")
+@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
 def vdb_migrate(scope: str):
-    if scope in ['knowledge', 'all']:
+    if scope in ["knowledge", "all"]:
         migrate_knowledge_vector_database()
-    if scope in ['annotation', 'all']:
+    if scope in ["annotation", "all"]:
         migrate_annotation_vector_database()
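These click commands are registered on the Flask CLI, so they are normally run as, for example, `flask vdb-migrate --scope knowledge`. For an isolated test, click's own test runner can invoke a command directly. A minimal sketch; the `commands` import path is the module shown here (api/commands.py), and a configured database and app context are assumed to be in place:

```python
from click.testing import CliRunner

from commands import vdb_migrate  # assumes api/ is on sys.path

runner = CliRunner()
# Restrict the migration to knowledge datasets, mirroring --scope above.
result = runner.invoke(vdb_migrate, ["--scope", "knowledge"])
print(result.exit_code, result.output)
```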
@ -146,7 +150,7 @@ def migrate_annotation_vector_database():
|
|||
"""
|
||||
Migrate annotation datas to target vector database .
|
||||
"""
|
||||
click.echo(click.style('Start migrate annotation data.', fg='green'))
|
||||
click.echo(click.style("Start migrate annotation data.", fg="green"))
|
||||
create_count = 0
|
||||
skipped_count = 0
|
||||
total_count = 0
|
||||
|
@ -154,98 +158,103 @@ def migrate_annotation_vector_database():
|
|||
while True:
|
||||
try:
|
||||
# get apps info
|
||||
apps = db.session.query(App).filter(
|
||||
App.status == 'normal'
|
||||
).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
|
||||
apps = (
|
||||
db.session.query(App)
|
||||
.filter(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.paginate(page=page, per_page=50)
|
||||
)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for app in apps:
|
||||
total_count = total_count + 1
|
||||
click.echo(f'Processing the {total_count} app {app.id}. '
|
||||
+ f'{create_count} created, {skipped_count} skipped.')
|
||||
click.echo(
|
||||
f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
|
||||
)
|
||||
try:
|
||||
click.echo('Create app annotation index: {}'.format(app.id))
|
||||
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app.id
|
||||
).first()
|
||||
click.echo("Create app annotation index: {}".format(app.id))
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo('App annotation setting is disabled: {}'.format(app.id))
|
||||
click.echo("App annotation setting is disabled: {}".format(app.id))
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
|
||||
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
|
||||
).first()
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo('App annotation collection binding is not exist: {}'.format(app.id))
|
||||
click.echo("App annotation collection binding is not exist: {}".format(app.id))
|
||||
continue
|
||||
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider=dataset_collection_binding.provider_name,
|
||||
embedding_model=dataset_collection_binding.model_name,
|
||||
collection_binding_id=dataset_collection_binding.id
|
||||
collection_binding_id=dataset_collection_binding.id,
|
||||
)
|
||||
documents = []
|
||||
if annotations:
|
||||
for annotation in annotations:
|
||||
document = Document(
|
||||
page_content=annotation.question,
|
||||
metadata={
|
||||
"annotation_id": annotation.id,
|
||||
"app_id": app.id,
|
||||
"doc_id": annotation.id
|
||||
}
|
||||
metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
|
||||
)
|
||||
documents.append(document)
|
||||
|
||||
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
|
||||
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
|
||||
|
||||
try:
|
||||
vector.delete()
|
||||
click.echo(
|
||||
click.style(f'Successfully delete vector index for app: {app.id}.',
|
||||
fg='green'))
|
||||
click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(f'Failed to delete vector index for app {app.id}.',
|
||||
fg='red'))
|
||||
click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
|
||||
raise e
|
||||
if documents:
|
||||
try:
|
||||
click.echo(click.style(
|
||||
f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
|
||||
fg='green'))
|
||||
vector.create(documents)
|
||||
click.echo(
|
||||
click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
|
||||
click.style(
|
||||
f"Start to created vector index with {len(documents)} annotations for app {app.id}.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
vector.create(documents)
|
||||
click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
|
||||
click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
|
||||
raise e
|
||||
click.echo(f'Successfully migrated app annotation {app.id}.')
|
||||
click.echo(f"Successfully migrated app annotation {app.id}.")
|
||||
create_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
click.style(
|
||||
"Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red"
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
click.echo(
|
||||
click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
|
||||
fg='green'))
|
||||
click.style(
|
||||
f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def migrate_knowledge_vector_database():
|
||||
"""
|
||||
Migrate vector database datas to target vector database .
|
||||
"""
|
||||
click.echo(click.style('Start migrate vector db.', fg='green'))
|
||||
click.echo(click.style("Start migrate vector db.", fg="green"))
|
||||
create_count = 0
|
||||
skipped_count = 0
|
||||
total_count = 0
|
||||
|
@ -253,87 +262,77 @@ def migrate_knowledge_vector_database():
|
|||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
datasets = (
|
||||
db.session.query(Dataset)
|
||||
.filter(Dataset.indexing_technique == "high_quality")
|
||||
.order_by(Dataset.created_at.desc())
|
||||
.paginate(page=page, per_page=50)
|
||||
)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
total_count = total_count + 1
|
||||
click.echo(f'Processing the {total_count} dataset {dataset.id}. '
|
||||
+ f'{create_count} created, {skipped_count} skipped.')
|
||||
click.echo(
|
||||
f"Processing the {total_count} dataset {dataset.id}. "
|
||||
+ f"{create_count} created, {skipped_count} skipped."
|
||||
)
|
||||
try:
|
||||
click.echo('Create dataset vdb index: {}'.format(dataset.id))
|
||||
click.echo("Create dataset vdb index: {}".format(dataset.id))
|
||||
if dataset.index_struct_dict:
|
||||
if dataset.index_struct_dict['type'] == vector_type:
|
||||
if dataset.index_struct_dict["type"] == vector_type:
|
||||
skipped_count = skipped_count + 1
|
||||
continue
|
||||
collection_name = ''
|
||||
collection_name = ""
|
||||
if vector_type == VectorType.WEAVIATE:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.WEAVIATE,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.QDRANT:
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
|
||||
one_or_none()
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
raise ValueError('Dataset Collection Bindings is not exist!')
|
||||
raise ValueError("Dataset Collection Bindings is not exist!")
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.QDRANT,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
|
||||
elif vector_type == VectorType.MILVUS:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.MILVUS,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.RELYT:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": 'relyt',
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.TENCENT:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.TENCENT,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.PGVECTOR:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.PGVECTOR,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.OPENSEARCH:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.OPENSEARCH,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
"vector_store": {"class_prefix": collection_name},
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.ANALYTICDB:
|
||||
|
@ -341,16 +340,13 @@ def migrate_knowledge_vector_database():
|
|||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.ANALYTICDB,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
"vector_store": {"class_prefix": collection_name},
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.ELASTICSEARCH:
|
||||
dataset_id = dataset.id
|
||||
index_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": 'elasticsearch',
|
||||
"vector_store": {"class_prefix": index_name}
|
||||
}
|
||||
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
else:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
@ -361,29 +357,41 @@ def migrate_knowledge_vector_database():
|
|||
try:
|
||||
vector.delete()
|
||||
click.echo(
|
||||
click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.',
|
||||
fg='green'))
|
||||
click.style(
|
||||
f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green"
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.',
|
||||
fg='red'))
|
||||
click.style(
|
||||
f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red"
|
||||
)
|
||||
)
|
||||
raise e
|
||||
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == 'completed',
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).all()
|
||||
dataset_documents = (
|
||||
db.session.query(DatasetDocument)
|
||||
.filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
documents = []
|
||||
segments_count = 0
|
||||
for dataset_document in dataset_documents:
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).all()
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
|
@ -393,7 +401,7 @@ def migrate_knowledge_vector_database():
|
|||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
|
@ -401,37 +409,43 @@ def migrate_knowledge_vector_database():
|
|||
|
||||
if documents:
|
||||
try:
|
||||
click.echo(click.style(
|
||||
f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.',
|
||||
fg='green'))
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
vector.create(documents)
|
||||
click.echo(
|
||||
click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green'))
|
||||
click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green")
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red'))
|
||||
click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
|
||||
raise e
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
click.echo(f'Successfully migrated dataset {dataset.id}.')
|
||||
click.echo(f"Successfully migrated dataset {dataset.id}.")
|
||||
create_count += 1
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
|
||||
)
|
||||
continue
|
||||
|
||||
click.echo(
|
||||
click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.',
|
||||
fg='green'))
|
||||
click.style(
|
||||
f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.')
|
||||
@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.")
|
||||
def convert_to_agent_apps():
|
||||
"""
|
||||
Convert Agent Assistant to Agent App.
|
||||
"""
|
||||
click.echo(click.style('Start convert to agent apps.', fg='green'))
|
||||
click.echo(click.style("Start convert to agent apps.", fg="green"))
|
||||
|
||||
proceeded_app_ids = []
|
||||
|
||||
|
@ -466,7 +480,7 @@ def convert_to_agent_apps():
|
|||
break
|
||||
|
||||
for app in apps:
|
||||
click.echo('Converting app: {}'.format(app.id))
|
||||
click.echo("Converting app: {}".format(app.id))
|
||||
|
||||
try:
|
||||
app.mode = AppMode.AGENT_CHAT.value
|
||||
|
@ -478,137 +492,142 @@ def convert_to_agent_apps():
|
|||
)
|
||||
|
||||
db.session.commit()
|
||||
click.echo(click.style('Converted app: {}'.format(app.id), fg='green'))
|
||||
click.echo(click.style("Converted app: {}".format(app.id), fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Convert app error: {} {}'.format(e.__class__.__name__,
|
||||
str(e)), fg='red'))
|
||||
click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red"))
|
||||
|
||||
click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green'))
|
||||
click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
|
||||
|
||||
|
||||
@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.')
@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.')
@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.")
@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.")
def add_qdrant_doc_id_index(field: str):
    click.echo(click.style('Start add qdrant doc_id index.', fg='green'))
    click.echo(click.style("Start add qdrant doc_id index.", fg="green"))
    vector_type = dify_config.VECTOR_STORE
    if vector_type != "qdrant":
        click.echo(click.style('Sorry, only support qdrant vector store.', fg='red'))
        click.echo(click.style("Sorry, only support qdrant vector store.", fg="red"))
        return
    create_count = 0

    try:
        bindings = db.session.query(DatasetCollectionBinding).all()
        if not bindings:
            click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red'))
            click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red"))
            return
        import qdrant_client
        from qdrant_client.http.exceptions import UnexpectedResponse
        from qdrant_client.http.models import PayloadSchemaType

        from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig

        for binding in bindings:
            if dify_config.QDRANT_URL is None:
                raise ValueError('Qdrant url is required.')
                raise ValueError("Qdrant url is required.")
            qdrant_config = QdrantConfig(
                endpoint=dify_config.QDRANT_URL,
                api_key=dify_config.QDRANT_API_KEY,
                root_path=current_app.root_path,
                timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
                grpc_port=dify_config.QDRANT_GRPC_PORT,
                prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
                prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
            )
            try:
                client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
                # create payload index
                client.create_payload_index(binding.collection_name, field,
                                            field_schema=PayloadSchemaType.KEYWORD)
                client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
                create_count += 1
            except UnexpectedResponse as e:
                # Collection does not exist, so return
                if e.status_code == 404:
                    click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red'))
                    click.echo(
                        click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red")
                    )
                    continue
                # Some other error occurred, so re-raise the exception
                else:
                    click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red'))
                    click.echo(
                        click.style(
                            f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red"
                        )
                    )

    except Exception as e:
        click.echo(click.style('Failed to create qdrant client.', fg='red'))
        click.echo(click.style("Failed to create qdrant client.", fg="red"))

    click.echo(
        click.style(f'Congratulations! Create {create_count} collection indexes.',
                    fg='green'))
    click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green"))

@click.command('create-tenant', help='Create account and tenant.')
@click.option('--email', prompt=True, help='The email address of the tenant account.')
@click.option('--language', prompt=True, help='Account language, default: en-US.')
def create_tenant(email: str, language: Optional[str] = None):
@click.command("create-tenant", help="Create account and tenant.")
@click.option("--email", prompt=True, help="The email address of the tenant account.")
@click.option("--name", prompt=True, help="The workspace name of the tenant account.")
@click.option("--language", prompt=True, help="Account language, default: en-US.")
def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None):
    """
    Create tenant account
    """
    if not email:
        click.echo(click.style('Sorry, email is required.', fg='red'))
        click.echo(click.style("Sorry, email is required.", fg="red"))
        return

    # Create account
    email = email.strip()

    if '@' not in email:
        click.echo(click.style('Sorry, invalid email address.', fg='red'))
    if "@" not in email:
        click.echo(click.style("Sorry, invalid email address.", fg="red"))
        return

    account_name = email.split('@')[0]
    account_name = email.split("@")[0]

    if language not in languages:
        language = 'en-US'
        language = "en-US"

    name = name.strip()

    # generate random password
    new_password = secrets.token_urlsafe(16)

    # register account
    account = RegisterService.register(
        email=email,
        name=account_name,
        password=new_password,
        language=language
    account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)

    TenantService.create_owner_tenant_if_not_exist(account, name)

    click.echo(
        click.style(
            "Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
            fg="green",
        )
    )

    TenantService.create_owner_tenant_if_not_exist(account)

    click.echo(click.style('Congratulations! Account and tenant created.\n'
                           'Account: {}\nPassword: {}'.format(email, new_password), fg='green'))

@click.command('upgrade-db', help='upgrade the database')
@click.command("upgrade-db", help="upgrade the database")
def upgrade_db():
    click.echo('Preparing database migration...')
    lock = redis_client.lock(name='db_upgrade_lock', timeout=60)
    click.echo("Preparing database migration...")
    lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
    if lock.acquire(blocking=False):
        try:
            click.echo(click.style('Start database migration.', fg='green'))
            click.echo(click.style("Start database migration.", fg="green"))

            # run db migration
            import flask_migrate

            flask_migrate.upgrade()

            click.echo(click.style('Database migration successful!', fg='green'))
            click.echo(click.style("Database migration successful!", fg="green"))

        except Exception as e:
            logging.exception(f'Database migration failed, error: {e}')
            logging.exception(f"Database migration failed, error: {e}")
        finally:
            lock.release()
    else:
        click.echo('Database migration skipped')
        click.echo("Database migration skipped")

@click.command('fix-app-site-missing', help='Fix app related site missing issue.')
@click.command("fix-app-site-missing", help="Fix app related site missing issue.")
def fix_app_site_missing():
    """
    Fix app related site missing issue.
    """
    click.echo(click.style('Start fix app related site missing issue.', fg='green'))
    click.echo(click.style("Start fix app related site missing issue.", fg="green"))

    failed_app_ids = []
    while True:

@@ -639,15 +658,14 @@ where sites.id is null limit 1000"""
                app_was_created.send(app, account=account)
            except Exception as e:
                failed_app_ids.append(app_id)
                click.echo(click.style('Fix app {} related site missing issue failed!'.format(app_id), fg='red'))
                logging.exception(f'Fix app related site missing issue failed, error: {e}')
                click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red"))
                logging.exception(f"Fix app related site missing issue failed, error: {e}")
                continue

        if not processed_count:
            break

    click.echo(click.style('Congratulations! Fix app related site missing issue successful!', fg='green'))
    click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green"))


def register_commands(app):

@@ -1,3 +1,3 @@
from .app_config import DifyConfig

dify_config = DifyConfig()
dify_config = DifyConfig()

@@ -1,4 +1,3 @@
from pydantic import Field, computed_field
from pydantic_settings import SettingsConfigDict

from configs.deploy import DeploymentConfig

@@ -24,42 +23,16 @@ class DifyConfig(
    # **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
    EnterpriseFeatureConfig,
):
    DEBUG: bool = Field(default=False, description='whether to enable debug mode.')

    model_config = SettingsConfigDict(
        # read from dotenv format config file
        env_file='.env',
        env_file_encoding='utf-8',
        env_file=".env",
        env_file_encoding="utf-8",
        frozen=True,
        # ignore extra attributes
        extra='ignore',
        extra="ignore",
    )

    CODE_MAX_NUMBER: int = 9223372036854775807
    CODE_MIN_NUMBER: int = -9223372036854775808
    CODE_MAX_STRING_LENGTH: int = 80000
    CODE_MAX_STRING_ARRAY_LENGTH: int = 30
    CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30
    CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000

    HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300
    HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600
    HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600
    HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10

    @computed_field
    def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str:
        return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'

    HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024

    @computed_field
    def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str:
        return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'

    SSRF_PROXY_HTTP_URL: str | None = None
    SSRF_PROXY_HTTPS_URL: str | None = None

    MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.')

    MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.')
    # Before adding any config,
    # please consider to arrange it in the proper config group of existed or added
    # for better readability and maintainability.
    # Thanks for your concentration and consideration.
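Note: the `model_config` hunk above is the heart of the settings class — `pydantic-settings` reads a dotenv file once at construction time, freezes the values, and silently ignores unknown keys. A minimal sketch of the same pattern, using a hypothetical `ExampleConfig` (not part of Dify's code) to show how an `.env` value overrides a field default:

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class ExampleConfig(BaseSettings):
    # mirrors the DifyConfig hunk: read .env, freeze values, ignore extras
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        frozen=True,
        extra="ignore",
    )

    DEBUG: bool = Field(default=False, description="whether to enable debug mode.")


# With `DEBUG=true` in .env (or exported in the shell), the default is overridden:
config = ExampleConfig()
print(config.DEBUG)
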
@@ -6,22 +6,28 @@ class DeploymentConfig(BaseSettings):
    """
    Deployment configs
    """

    APPLICATION_NAME: str = Field(
        description='application name',
        default='langgenius/dify',
        description="application name",
        default="langgenius/dify",
    )

    DEBUG: bool = Field(
        description="whether to enable debug mode.",
        default=False,
    )

    TESTING: bool = Field(
        description='',
        description="",
        default=False,
    )

    EDITION: str = Field(
        description='deployment edition',
        default='SELF_HOSTED',
        description="deployment edition",
        default="SELF_HOSTED",
    )

    DEPLOY_ENV: str = Field(
        description='deployment environment, default to PRODUCTION.',
        default='PRODUCTION',
        description="deployment environment, default to PRODUCTION.",
        default="PRODUCTION",
    )

@@ -7,13 +7,14 @@ class EnterpriseFeatureConfig(BaseSettings):
    Enterprise feature configs.
    **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
    """

    ENTERPRISE_ENABLED: bool = Field(
        description='whether to enable enterprise features.'
                    'Before using, please contact business@dify.ai by email to inquire about licensing matters.',
        description="whether to enable enterprise features."
        "Before using, please contact business@dify.ai by email to inquire about licensing matters.",
        default=False,
    )

    CAN_REPLACE_LOGO: bool = Field(
        description='whether to allow replacing enterprise logo.',
        description="whether to allow replacing enterprise logo.",
        default=False,
    )

@@ -8,27 +8,28 @@ class NotionConfig(BaseSettings):
    """
    Notion integration configs
    """

    NOTION_CLIENT_ID: Optional[str] = Field(
        description='Notion client ID',
        description="Notion client ID",
        default=None,
    )

    NOTION_CLIENT_SECRET: Optional[str] = Field(
        description='Notion client secret key',
        description="Notion client secret key",
        default=None,
    )

    NOTION_INTEGRATION_TYPE: Optional[str] = Field(
        description='Notion integration type, default to None, available values: internal.',
        description="Notion integration type, default to None, available values: internal.",
        default=None,
    )

    NOTION_INTERNAL_SECRET: Optional[str] = Field(
        description='Notion internal secret key',
        description="Notion internal secret key",
        default=None,
    )

    NOTION_INTEGRATION_TOKEN: Optional[str] = Field(
        description='Notion integration token',
        description="Notion integration token",
        default=None,
    )

@@ -8,17 +8,18 @@ class SentryConfig(BaseSettings):
    """
    Sentry configs
    """

    SENTRY_DSN: Optional[str] = Field(
        description='Sentry DSN',
        description="Sentry DSN",
        default=None,
    )

    SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field(
        description='Sentry trace sample rate',
        description="Sentry trace sample rate",
        default=1.0,
    )

    SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field(
        description='Sentry profiles sample rate',
        description="Sentry profiles sample rate",
        default=1.0,
    )

@@ -1,6 +1,6 @@
from typing import Optional
from typing import Annotated, Optional

from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings

from configs.feature.hosted_service import HostedServiceConfig

@@ -10,16 +10,17 @@ class SecurityConfig(BaseSettings):
    """
    Secret Key configs
    """

    SECRET_KEY: Optional[str] = Field(
        description='Your App secret key will be used for securely signing the session cookie'
                    'Make sure you are changing this key for your deployment with a strong key.'
                    'You can generate a strong key using `openssl rand -base64 42`.'
                    'Alternatively you can set it with `SECRET_KEY` environment variable.',
        description="Your App secret key will be used for securely signing the session cookie"
        "Make sure you are changing this key for your deployment with a strong key."
        "You can generate a strong key using `openssl rand -base64 42`."
        "Alternatively you can set it with `SECRET_KEY` environment variable.",
        default=None,
    )

    RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
        description='Expiry time in hours for reset token',
        description="Expiry time in hours for reset token",
        default=24,
    )

@@ -28,12 +29,13 @@ class AppExecutionConfig(BaseSettings):
    """
    App Execution configs
    """

    APP_MAX_EXECUTION_TIME: PositiveInt = Field(
        description='execution timeout in seconds for app execution',
        description="execution timeout in seconds for app execution",
        default=1200,
    )
    APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
        description='max active request per app, 0 means unlimited',
        description="max active request per app, 0 means unlimited",
        default=0,
    )

@@ -42,14 +44,70 @@ class CodeExecutionSandboxConfig(BaseSettings):
    """
    Code Execution Sandbox configs
    """
    CODE_EXECUTION_ENDPOINT: str = Field(
        description='endpoint URL of code execution servcie',
        default='http://sandbox:8194',

    CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
        description="endpoint URL of code execution servcie",
        default="http://sandbox:8194",
    )

    CODE_EXECUTION_API_KEY: str = Field(
        description='API key for code execution service',
        default='dify-sandbox',
        description="API key for code execution service",
        default="dify-sandbox",
    )

    CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field(
        description="connect timeout in seconds for code execution request",
        default=10.0,
    )

    CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field(
        description="read timeout in seconds for code execution request",
        default=60.0,
    )

    CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field(
        description="write timeout in seconds for code execution request",
        default=10.0,
    )

    CODE_MAX_NUMBER: PositiveInt = Field(
        description="max depth for code execution",
        default=9223372036854775807,
    )

    CODE_MIN_NUMBER: NegativeInt = Field(
        description="",
        default=-9223372036854775807,
    )

    CODE_MAX_DEPTH: PositiveInt = Field(
        description="max depth for code execution",
        default=5,
    )

    CODE_MAX_PRECISION: PositiveInt = Field(
        description="max precision digits for float type in code execution",
        default=20,
    )

    CODE_MAX_STRING_LENGTH: PositiveInt = Field(
        description="max string length for code execution",
        default=80000,
    )

    CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
        description="",
        default=30,
    )

    CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field(
        description="",
        default=30,
    )

    CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field(
        description="",
        default=1000,
    )

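Note: the three `CODE_EXECUTION_*_TIMEOUT` fields added above split the sandbox HTTP call into independently tunable phases. A hedged sketch of how such values are typically wired into an `httpx` client — the actual call site is not part of this diff, so the endpoint path and the `pool` value below are illustrative only:

import httpx

# Values as they would come from the config fields above.
CODE_EXECUTION_CONNECT_TIMEOUT = 10.0
CODE_EXECUTION_READ_TIMEOUT = 60.0
CODE_EXECUTION_WRITE_TIMEOUT = 10.0

# httpx lets each phase of the request carry its own timeout.
timeout = httpx.Timeout(
    connect=CODE_EXECUTION_CONNECT_TIMEOUT,
    read=CODE_EXECUTION_READ_TIMEOUT,
    write=CODE_EXECUTION_WRITE_TIMEOUT,
    pool=10.0,  # illustrative; not a config field in this diff
)

# hypothetical endpoint path, shown only to demonstrate passing the timeout
response = httpx.post("http://sandbox:8194/v1/sandbox/run", json={"code": "print(1)"}, timeout=timeout)
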
@@ -57,28 +115,27 @@ class EndpointConfig(BaseSettings):
    """
    Module URL configs
    """

    CONSOLE_API_URL: str = Field(
        description='The backend URL prefix of the console API.'
                    'used to concatenate the login authorization callback or notion integration callback.',
        default='',
        description="The backend URL prefix of the console API."
        "used to concatenate the login authorization callback or notion integration callback.",
        default="",
    )

    CONSOLE_WEB_URL: str = Field(
        description='The front-end URL prefix of the console web.'
                    'used to concatenate some front-end addresses and for CORS configuration use.',
        default='',
        description="The front-end URL prefix of the console web."
        "used to concatenate some front-end addresses and for CORS configuration use.",
        default="",
    )

    SERVICE_API_URL: str = Field(
        description='Service API Url prefix.'
                    'used to display Service API Base Url to the front-end.',
        default='',
        description="Service API Url prefix." "used to display Service API Base Url to the front-end.",
        default="",
    )

    APP_WEB_URL: str = Field(
        description='WebApp Url prefix.'
                    'used to display WebAPP API Base Url to the front-end.',
        default='',
        description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.",
        default="",
    )

@@ -86,17 +143,18 @@ class FileAccessConfig(BaseSettings):
    """
    File Access configs
    """

    FILES_URL: str = Field(
        description='File preview or download Url prefix.'
                    ' used to display File preview or download Url to the front-end or as Multi-model inputs;'
                    'Url is signed and has expiration time.',
        validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'),
        description="File preview or download Url prefix."
        " used to display File preview or download Url to the front-end or as Multi-model inputs;"
        "Url is signed and has expiration time.",
        validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"),
        alias_priority=1,
        default='',
        default="",
    )

    FILES_ACCESS_TIMEOUT: int = Field(
        description='timeout in seconds for file accessing',
        description="timeout in seconds for file accessing",
        default=300,
    )

@@ -105,23 +163,24 @@ class FileUploadConfig(BaseSettings):
    """
    File Uploading configs
    """

    UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field(
        description='size limit in Megabytes for uploading files',
        description="size limit in Megabytes for uploading files",
        default=15,
    )

    UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field(
        description='batch size limit for uploading files',
        description="batch size limit for uploading files",
        default=5,
    )

    UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
        description='image file size limit in Megabytes for uploading files',
        description="image file size limit in Megabytes for uploading files",
        default=10,
    )

    BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
        description='',  # todo: to be clarified
        description="",  # todo: to be clarified
        default=20,
    )

@@ -130,45 +189,79 @@ class HttpConfig(BaseSettings):
    """
    HTTP configs
    """

    API_COMPRESSION_ENABLED: bool = Field(
        description='whether to enable HTTP response compression of gzip',
        description="whether to enable HTTP response compression of gzip",
        default=False,
    )

    inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
        description='',
        validation_alias=AliasChoices('CONSOLE_CORS_ALLOW_ORIGINS', 'CONSOLE_WEB_URL'),
        default='',
        description="",
        validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"),
        default="",
    )

    @computed_field
    @property
    def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
        return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',')
        return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",")

    inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field(
        description='',
        validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'),
        default='*',
        description="",
        validation_alias=AliasChoices("WEB_API_CORS_ALLOW_ORIGINS"),
        default="*",
    )

    @computed_field
    @property
    def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
        return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',')
        return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")

    HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
        PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
    ] = 10

    HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
        PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
    ] = 60

    HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
        PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request")
    ] = 20

    HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
        description="",
        default=10 * 1024 * 1024,
    )

    HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field(
        description="",
        default=1 * 1024 * 1024,
    )

    SSRF_PROXY_HTTP_URL: Optional[str] = Field(
        description="HTTP URL for SSRF proxy",
        default=None,
    )

    SSRF_PROXY_HTTPS_URL: Optional[str] = Field(
        description="HTTPS URL for SSRF proxy",
        default=None,
    )


class InnerAPIConfig(BaseSettings):
    """
    Inner API configs
    """

    INNER_API: bool = Field(
        description='whether to enable the inner API',
        description="whether to enable the inner API",
        default=False,
    )

    INNER_API_KEY: Optional[str] = Field(
        description='The inner API key is used to authenticate the inner API',
        description="The inner API key is used to authenticate the inner API",
        default=None,
    )

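Worth noting in the `HttpConfig` hunk: `AliasChoices` makes `CONSOLE_CORS_ALLOW_ORIGINS` fall back to `CONSOLE_WEB_URL` when the dedicated variable is unset, and the `@computed_field` exposes the comma-separated string as a list. A small sketch of that fallback-and-split pattern in isolation (the `CorsDemo` class is illustrative, not Dify code):

from pydantic import AliasChoices, Field, computed_field
from pydantic_settings import BaseSettings


class CorsDemo(BaseSettings):
    # First matching env var wins: CONSOLE_CORS_ALLOW_ORIGINS, then CONSOLE_WEB_URL.
    inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
        validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"),
        default="",
    )

    @computed_field
    @property
    def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
        return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",")


# e.g. CONSOLE_WEB_URL=https://console.example.com yields ["https://console.example.com"]
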
@@ -179,28 +272,27 @@ class LoggingConfig(BaseSettings):
    """

    LOG_LEVEL: str = Field(
        description='Log output level, default to INFO.'
                    'It is recommended to set it to ERROR for production.',
        default='INFO',
        description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.",
        default="INFO",
    )

    LOG_FILE: Optional[str] = Field(
        description='logging output file path',
        description="logging output file path",
        default=None,
    )

    LOG_FORMAT: str = Field(
        description='log format',
        default='%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s',
        description="log format",
        default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
    )

    LOG_DATEFORMAT: Optional[str] = Field(
        description='log date format',
        description="log date format",
        default=None,
    )

    LOG_TZ: Optional[str] = Field(
        description='specify log timezone, eg: America/New_York',
        description="specify log timezone, eg: America/New_York",
        default=None,
    )

@@ -209,8 +301,9 @@ class ModelLoadBalanceConfig(BaseSettings):
    """
    Model load balance configs
    """

    MODEL_LB_ENABLED: bool = Field(
        description='whether to enable model load balancing',
        description="whether to enable model load balancing",
        default=False,
    )

@@ -219,8 +312,9 @@ class BillingConfig(BaseSettings):
    """
    Platform Billing Configurations
    """

    BILLING_ENABLED: bool = Field(
        description='whether to enable billing',
        description="whether to enable billing",
        default=False,
    )

@@ -229,9 +323,10 @@ class UpdateConfig(BaseSettings):
    """
    Update configs
    """

    CHECK_UPDATE_URL: str = Field(
        description='url for checking updates',
        default='https://updates.dify.ai',
        description="url for checking updates",
        default="https://updates.dify.ai",
    )

@@ -241,47 +336,53 @@ class WorkflowConfig(BaseSettings):
    """

    WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field(
        description='max execution steps in single workflow execution',
        description="max execution steps in single workflow execution",
        default=500,
    )

    WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field(
        description='max execution time in seconds in single workflow execution',
        description="max execution time in seconds in single workflow execution",
        default=1200,
    )

    WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field(
        description='max depth of calling in single workflow execution',
        description="max depth of calling in single workflow execution",
        default=5,
    )

    MAX_VARIABLE_SIZE: PositiveInt = Field(
        description="The maximum size in bytes of a variable. default to 5KB.",
        default=5 * 1024,
    )


class OAuthConfig(BaseSettings):
    """
    oauth configs
    """

    OAUTH_REDIRECT_PATH: str = Field(
        description='redirect path for OAuth',
        default='/console/api/oauth/authorize',
        description="redirect path for OAuth",
        default="/console/api/oauth/authorize",
    )

    GITHUB_CLIENT_ID: Optional[str] = Field(
        description='GitHub client id for OAuth',
        description="GitHub client id for OAuth",
        default=None,
    )

    GITHUB_CLIENT_SECRET: Optional[str] = Field(
        description='GitHub client secret key for OAuth',
        description="GitHub client secret key for OAuth",
        default=None,
    )

    GOOGLE_CLIENT_ID: Optional[str] = Field(
        description='Google client id for OAuth',
        description="Google client id for OAuth",
        default=None,
    )

    GOOGLE_CLIENT_SECRET: Optional[str] = Field(
        description='Google client secret key for OAuth',
        description="Google client secret key for OAuth",
        default=None,
    )

@@ -291,9 +392,8 @@ class ModerationConfig(BaseSettings):
    Moderation in app configs.
    """

    # todo: to be clarified in usage and unit
    OUTPUT_MODERATION_BUFFER_SIZE: PositiveInt = Field(
        description='buffer size for moderation',
    MODERATION_BUFFER_SIZE: PositiveInt = Field(
        description="buffer size for moderation",
        default=300,
    )

@@ -304,7 +404,7 @@ class ToolConfig(BaseSettings):
    """

    TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field(
        description='max age in seconds for tool icon caching',
        description="max age in seconds for tool icon caching",
        default=3600,
    )

@@ -315,52 +415,52 @@ class MailConfig(BaseSettings):
    """

    MAIL_TYPE: Optional[str] = Field(
        description='Mail provider type name, default to None, availabile values are `smtp` and `resend`.',
        description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.",
        default=None,
    )

    MAIL_DEFAULT_SEND_FROM: Optional[str] = Field(
        description='default email address for sending from ',
        description="default email address for sending from ",
        default=None,
    )

    RESEND_API_KEY: Optional[str] = Field(
        description='API key for Resend',
        description="API key for Resend",
        default=None,
    )

    RESEND_API_URL: Optional[str] = Field(
        description='API URL for Resend',
        description="API URL for Resend",
        default=None,
    )

    SMTP_SERVER: Optional[str] = Field(
        description='smtp server host',
        description="smtp server host",
        default=None,
    )

    SMTP_PORT: Optional[int] = Field(
        description='smtp server port',
        description="smtp server port",
        default=465,
    )

    SMTP_USERNAME: Optional[str] = Field(
        description='smtp server username',
        description="smtp server username",
        default=None,
    )

    SMTP_PASSWORD: Optional[str] = Field(
        description='smtp server password',
        description="smtp server password",
        default=None,
    )

    SMTP_USE_TLS: bool = Field(
        description='whether to use TLS connection to smtp server',
        description="whether to use TLS connection to smtp server",
        default=False,
    )

    SMTP_OPPORTUNISTIC_TLS: bool = Field(
        description='whether to use opportunistic TLS connection to smtp server',
        description="whether to use opportunistic TLS connection to smtp server",
        default=False,
    )

@@ -371,22 +471,22 @@ class RagEtlConfig(BaseSettings):
    """

    ETL_TYPE: str = Field(
        description='RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ',
        default='dify',
        description="RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ",
        default="dify",
    )

    KEYWORD_DATA_SOURCE_TYPE: str = Field(
        description='source type for keyword data, default to `database`, available values are `database` .',
        default='database',
        description="source type for keyword data, default to `database`, available values are `database` .",
        default="database",
    )

    UNSTRUCTURED_API_URL: Optional[str] = Field(
        description='API URL for Unstructured',
        description="API URL for Unstructured",
        default=None,
    )

    UNSTRUCTURED_API_KEY: Optional[str] = Field(
        description='API key for Unstructured',
        description="API key for Unstructured",
        default=None,
    )

@@ -397,22 +497,23 @@ class DataSetConfig(BaseSettings):
    """

    CLEAN_DAY_SETTING: PositiveInt = Field(
        description='interval in days for cleaning up dataset',
        description="interval in days for cleaning up dataset",
        default=30,
    )

    DATASET_OPERATOR_ENABLED: bool = Field(
        description='whether to enable dataset operator',
        description="whether to enable dataset operator",
        default=False,
    )


class WorkspaceConfig(BaseSettings):
    """
    Workspace configs
    """

    INVITE_EXPIRY_HOURS: PositiveInt = Field(
        description='workspaces invitation expiration in hours',
        description="workspaces invitation expiration in hours",
        default=72,
    )

@@ -423,25 +524,81 @@ class IndexingConfig(BaseSettings):
    """

    INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field(
        description='max segmentation token length for indexing',
        description="max segmentation token length for indexing",
        default=1000,
    )


class ImageFormatConfig(BaseSettings):
    MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
        description='multi model send image format, support base64, url, default is base64',
        default='base64',
        description="multi model send image format, support base64, url, default is base64",
        default="base64",
    )


class CeleryBeatConfig(BaseSettings):
    CELERY_BEAT_SCHEDULER_TIME: int = Field(
        description='the time of the celery scheduler, default to 1 day',
        description="the time of the celery scheduler, default to 1 day",
        default=1,
    )


class PositionConfig(BaseSettings):
    POSITION_PROVIDER_PINS: str = Field(
        description="The heads of model providers",
        default="",
    )

    POSITION_PROVIDER_INCLUDES: str = Field(
        description="The included model providers",
        default="",
    )

    POSITION_PROVIDER_EXCLUDES: str = Field(
        description="The excluded model providers",
        default="",
    )

    POSITION_TOOL_PINS: str = Field(
        description="The heads of tools",
        default="",
    )

    POSITION_TOOL_INCLUDES: str = Field(
        description="The included tools",
        default="",
    )

    POSITION_TOOL_EXCLUDES: str = Field(
        description="The excluded tools",
        default="",
    )

    @computed_field
    def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
        return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""]

    @computed_field
    def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
        return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""}

    @computed_field
    def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
        return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""}

    @computed_field
    def POSITION_TOOL_PINS_LIST(self) -> list[str]:
        return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""]

    @computed_field
    def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
        return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""}

    @computed_field
    def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
        return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}

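The `PositionConfig` class added above only stores comma-separated strings; the computed pins/includes/excludes views are what callers consume. A hedged sketch of the ordering-and-filtering those sets enable — the helper that does this inside Dify is not part of this diff, so `sort_and_filter` below is illustrative:

def sort_and_filter(names: list[str], pins: list[str], includes: set[str], excludes: set[str]) -> list[str]:
    # keep only allowed names: inside `includes` (when non-empty) and not excluded
    allowed = [n for n in names if (not includes or n in includes) and n not in excludes]
    # pinned names float to the front, in pin order; the rest keep their order
    pinned = [n for n in pins if n in allowed]
    rest = [n for n in allowed if n not in pinned]
    return pinned + rest


providers = ["openai", "anthropic", "zhipuai"]
print(sort_and_filter(providers, pins=["anthropic"], includes=set(), excludes={"zhipuai"}))
# -> ['anthropic', 'openai']
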
class FeatureConfig(
    # place the configs in alphabet order
    AppExecutionConfig,

@@ -466,7 +623,7 @@ class FeatureConfig(
    UpdateConfig,
    WorkflowConfig,
    WorkspaceConfig,

    PositionConfig,
    # hosted services config
    HostedServiceConfig,
    CeleryBeatConfig,

@@ -10,62 +10,62 @@ class HostedOpenAiConfig(BaseSettings):
    """

    HOSTED_OPENAI_API_KEY: Optional[str] = Field(
        description='',
        description="",
        default=None,
    )

    HOSTED_OPENAI_API_BASE: Optional[str] = Field(
        description='',
        description="",
        default=None,
    )

    HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field(
        description='',
        description="",
        default=None,
    )

    HOSTED_OPENAI_TRIAL_ENABLED: bool = Field(
        description='',
        description="",
        default=False,
    )

    HOSTED_OPENAI_TRIAL_MODELS: str = Field(
        description='',
        default='gpt-3.5-turbo,'
                'gpt-3.5-turbo-1106,'
                'gpt-3.5-turbo-instruct,'
                'gpt-3.5-turbo-16k,'
                'gpt-3.5-turbo-16k-0613,'
                'gpt-3.5-turbo-0613,'
                'gpt-3.5-turbo-0125,'
                'text-davinci-003',
        description="",
        default="gpt-3.5-turbo,"
        "gpt-3.5-turbo-1106,"
        "gpt-3.5-turbo-instruct,"
        "gpt-3.5-turbo-16k,"
        "gpt-3.5-turbo-16k-0613,"
        "gpt-3.5-turbo-0613,"
        "gpt-3.5-turbo-0125,"
        "text-davinci-003",
    )

    HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
        description='',
        description="",
        default=200,
    )

    HOSTED_OPENAI_PAID_ENABLED: bool = Field(
        description='',
        description="",
        default=False,
    )

    HOSTED_OPENAI_PAID_MODELS: str = Field(
        description='',
        default='gpt-4,'
                'gpt-4-turbo-preview,'
                'gpt-4-turbo-2024-04-09,'
                'gpt-4-1106-preview,'
                'gpt-4-0125-preview,'
                'gpt-3.5-turbo,'
                'gpt-3.5-turbo-16k,'
                'gpt-3.5-turbo-16k-0613,'
                'gpt-3.5-turbo-1106,'
                'gpt-3.5-turbo-0613,'
                'gpt-3.5-turbo-0125,'
                'gpt-3.5-turbo-instruct,'
                'text-davinci-003',
        description="",
        default="gpt-4,"
        "gpt-4-turbo-preview,"
        "gpt-4-turbo-2024-04-09,"
        "gpt-4-1106-preview,"
        "gpt-4-0125-preview,"
        "gpt-3.5-turbo,"
        "gpt-3.5-turbo-16k,"
        "gpt-3.5-turbo-16k-0613,"
        "gpt-3.5-turbo-1106,"
        "gpt-3.5-turbo-0613,"
        "gpt-3.5-turbo-0125,"
        "gpt-3.5-turbo-instruct,"
        "text-davinci-003",
    )

@@ -75,22 +75,22 @@ class HostedAzureOpenAiConfig(BaseSettings):
    """

    HOSTED_AZURE_OPENAI_ENABLED: bool = Field(
        description='',
        description="",
        default=False,
    )

    HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
        description='',
        description="",
        default=None,
    )

    HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field(
        description='',
        description="",
        default=None,
    )

    HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
        description='',
        description="",
        default=200,
    )

@@ -101,27 +101,27 @@ class HostedAnthropicConfig(BaseSettings):
    """

    HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field(
        description='',
        description="",
        default=None,
    )

    HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field(
        description='',
        description="",
        default=None,
    )

    HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field(
        description='',
        description="",
        default=False,
    )

    HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
        description='',
        description="",
        default=600000,
    )

    HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
        description='',
        description="",
        default=False,
    )

@@ -132,7 +132,7 @@ class HostedMinmaxConfig(BaseSettings):
    """

    HOSTED_MINIMAX_ENABLED: bool = Field(
        description='',
        description="",
        default=False,
    )

@@ -143,7 +143,7 @@ class HostedSparkConfig(BaseSettings):
    """

    HOSTED_SPARK_ENABLED: bool = Field(
        description='',
        description="",
        default=False,
    )

@@ -154,7 +154,7 @@ class HostedZhipuAIConfig(BaseSettings):
    """

    HOSTED_ZHIPUAI_ENABLED: bool = Field(
        description='',
        description="",
        default=False,
    )

@@ -165,13 +165,13 @@ class HostedModerationConfig(BaseSettings):
    """

    HOSTED_MODERATION_ENABLED: bool = Field(
        description='',
        description="",
        default=False,
    )

    HOSTED_MODERATION_PROVIDERS: str = Field(
        description='',
        default='',
        description="",
        default="",
    )

@@ -181,15 +181,15 @@ class HostedFetchAppTemplateConfig(BaseSettings):
    """

    HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
        description='the mode for fetching app templates,'
                    ' default to remote,'
                    ' available values: remote, db, builtin',
        default='remote',
        description="the mode for fetching app templates,"
        " default to remote,"
        " available values: remote, db, builtin",
        default="remote",
    )

    HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field(
        description='the domain for fetching remote app templates',
        default='https://tmpl.dify.ai',
        description="the domain for fetching remote app templates",
        default="https://tmpl.dify.ai",
    )

@@ -202,7 +202,6 @@ class HostedServiceConfig(
    HostedOpenAiConfig,
    HostedSparkConfig,
    HostedZhipuAIConfig,

    # moderation
    HostedModerationConfig,
):

@@ -13,6 +13,7 @@ from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig

@@ -28,108 +29,108 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig

class StorageConfig(BaseSettings):
    STORAGE_TYPE: str = Field(
        description='storage type,'
                    ' default to `local`,'
                    ' available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.',
        default='local',
        description="storage type,"
        " default to `local`,"
        " available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.",
        default="local",
    )

    STORAGE_LOCAL_PATH: str = Field(
        description='local storage path',
        default='storage',
        description="local storage path",
        default="storage",
    )


class VectorStoreConfig(BaseSettings):
    VECTOR_STORE: Optional[str] = Field(
        description='vector store type',
        description="vector store type",
        default=None,
    )


class KeywordStoreConfig(BaseSettings):
    KEYWORD_STORE: str = Field(
        description='keyword store type',
        default='jieba',
        description="keyword store type",
        default="jieba",
    )


class DatabaseConfig:
    DB_HOST: str = Field(
        description='db host',
        default='localhost',
        description="db host",
        default="localhost",
    )

    DB_PORT: PositiveInt = Field(
        description='db port',
        description="db port",
        default=5432,
    )

    DB_USERNAME: str = Field(
        description='db username',
        default='postgres',
        description="db username",
        default="postgres",
    )

    DB_PASSWORD: str = Field(
        description='db password',
        default='',
        description="db password",
        default="",
    )

    DB_DATABASE: str = Field(
        description='db database',
        default='dify',
        description="db database",
        default="dify",
    )

    DB_CHARSET: str = Field(
        description='db charset',
        default='',
        description="db charset",
        default="",
    )

    DB_EXTRAS: str = Field(
        description='db extras options. Example: keepalives_idle=60&keepalives=1',
        default='',
        description="db extras options. Example: keepalives_idle=60&keepalives=1",
        default="",
    )

    SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
        description='db uri scheme',
        default='postgresql',
        description="db uri scheme",
        default="postgresql",
    )

    @computed_field
    @property
    def SQLALCHEMY_DATABASE_URI(self) -> str:
        db_extras = (
            f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}"
            if self.DB_CHARSET
            else self.DB_EXTRAS
            f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS
        ).strip("&")
        db_extras = f"?{db_extras}" if db_extras else ""
        return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
                f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
                f"{db_extras}")
        return (
            f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
            f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
            f"{db_extras}"
        )

    SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field(
        description='pool size of SqlAlchemy',
        description="pool size of SqlAlchemy",
        default=30,
    )

    SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field(
        description='max overflows for SqlAlchemy',
        description="max overflows for SqlAlchemy",
        default=10,
    )

    SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field(
        description='SqlAlchemy pool recycle',
        description="SqlAlchemy pool recycle",
        default=3600,
    )

    SQLALCHEMY_POOL_PRE_PING: bool = Field(
        description='whether to enable pool pre-ping in SqlAlchemy',
        description="whether to enable pool pre-ping in SqlAlchemy",
        default=False,
    )

    SQLALCHEMY_ECHO: bool | str = Field(
        description='whether to enable SqlAlchemy echo',
        description="whether to enable SqlAlchemy echo",
        default=False,
    )

@@ -137,35 +138,38 @@ class DatabaseConfig:
    @property
    def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
        return {
            'pool_size': self.SQLALCHEMY_POOL_SIZE,
            'max_overflow': self.SQLALCHEMY_MAX_OVERFLOW,
            'pool_recycle': self.SQLALCHEMY_POOL_RECYCLE,
            'pool_pre_ping': self.SQLALCHEMY_POOL_PRE_PING,
            'connect_args': {'options': '-c timezone=UTC'},
            "pool_size": self.SQLALCHEMY_POOL_SIZE,
            "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
            "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
            "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
            "connect_args": {"options": "-c timezone=UTC"},
        }


class CeleryConfig(DatabaseConfig):
    CELERY_BACKEND: str = Field(
        description='Celery backend, available values are `database`, `redis`',
        default='database',
        description="Celery backend, available values are `database`, `redis`",
        default="database",
    )

    CELERY_BROKER_URL: Optional[str] = Field(
        description='CELERY_BROKER_URL',
        description="CELERY_BROKER_URL",
        default=None,
    )

    @computed_field
    @property
    def CELERY_RESULT_BACKEND(self) -> str | None:
        return 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
            if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
        return (
            "db+{}".format(self.SQLALCHEMY_DATABASE_URI)
            if self.CELERY_BACKEND == "database"
            else self.CELERY_BROKER_URL
        )

    @computed_field
    @property
    def BROKER_USE_SSL(self) -> bool:
        return self.CELERY_BROKER_URL.startswith('rediss://') if self.CELERY_BROKER_URL else False
        return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False

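The reflowed `CELERY_RESULT_BACKEND` property above picks the SQLAlchemy URI (prefixed with `db+`) when the backend is `database`, and otherwise reuses the broker URL. A quick worked example of the strings these computed fields produce, under assumed values (the password and broker URL here are placeholders, not values from this diff):

from urllib.parse import quote_plus

# assumed values matching the DatabaseConfig defaults in this diff
scheme, user, password = "postgresql", "postgres", "example-password"
host, port, database = "localhost", 5432, "dify"

sqlalchemy_uri = f"{scheme}://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{database}"
print(sqlalchemy_uri)  # postgresql://postgres:example-password@localhost:5432/dify

celery_backend = "database"
celery_broker_url = "redis://localhost:6379/1"
result_backend = f"db+{sqlalchemy_uri}" if celery_backend == "database" else celery_broker_url
print(result_backend)  # db+postgresql://postgres:example-password@localhost:5432/dify
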
class MiddlewareConfig(
|
||||
|
@ -174,7 +178,6 @@ class MiddlewareConfig(
|
|||
DatabaseConfig,
|
||||
KeywordStoreConfig,
|
||||
RedisConfig,
|
||||
|
||||
# configs of storage and storage providers
|
||||
StorageConfig,
|
||||
AliyunOSSStorageConfig,
|
||||
|
@ -183,7 +186,6 @@ class MiddlewareConfig(
|
|||
TencentCloudCOSStorageConfig,
|
||||
S3StorageConfig,
|
||||
OCIStorageConfig,
|
||||
|
||||
# configs of vdb and vdb providers
|
||||
VectorStoreConfig,
|
||||
AnalyticdbConfig,
|
||||
|
@ -199,5 +201,6 @@ class MiddlewareConfig(
|
|||
TencentVectorDBConfig,
|
||||
TiDBVectorConfig,
|
||||
WeaviateConfig,
|
||||
ElasticsearchConfig,
|
||||
):
|
||||
pass
|
||||
|
|
15
api/configs/middleware/cache/redis_config.py
vendored
15
api/configs/middleware/cache/redis_config.py
vendored
|
@ -8,32 +8,33 @@ class RedisConfig(BaseSettings):
|
|||
"""
|
||||
Redis configs
|
||||
"""
|
||||
|
||||
REDIS_HOST: str = Field(
|
||||
description='Redis host',
|
||||
default='localhost',
|
||||
description="Redis host",
|
||||
default="localhost",
|
||||
)
|
||||
|
||||
REDIS_PORT: PositiveInt = Field(
|
||||
description='Redis port',
|
||||
description="Redis port",
|
||||
default=6379,
|
||||
)
|
||||
|
||||
REDIS_USERNAME: Optional[str] = Field(
|
||||
description='Redis username',
|
||||
description="Redis username",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_PASSWORD: Optional[str] = Field(
|
||||
description='Redis password',
|
||||
description="Redis password",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_DB: NonNegativeInt = Field(
|
||||
description='Redis database id, default to 0',
|
||||
description="Redis database id, default to 0",
|
||||
default=0,
|
||||
)
|
||||
|
||||
REDIS_USE_SSL: bool = Field(
|
||||
description='whether to use SSL for Redis connection',
|
||||
description="whether to use SSL for Redis connection",
|
||||
default=False,
|
||||
)
|
||||
|
|
|
@ -10,31 +10,36 @@ class AliyunOSSStorageConfig(BaseSettings):
|
|||
"""
|
||||
|
||||
ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field(
|
||||
description='Aliyun OSS bucket name',
|
||||
description="Aliyun OSS bucket name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field(
|
||||
description='Aliyun OSS access key',
|
||||
description="Aliyun OSS access key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_SECRET_KEY: Optional[str] = Field(
|
||||
description='Aliyun OSS secret key',
|
||||
description="Aliyun OSS secret key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_ENDPOINT: Optional[str] = Field(
|
||||
description='Aliyun OSS endpoint URL',
|
||||
description="Aliyun OSS endpoint URL",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_REGION: Optional[str] = Field(
|
||||
description='Aliyun OSS region',
|
||||
description="Aliyun OSS region",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field(
|
||||
description='Aliyun OSS authentication version',
|
||||
description="Aliyun OSS authentication version",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_PATH: Optional[str] = Field(
|
||||
description="Aliyun OSS path",
|
||||
default=None,
|
||||
)
|
||||
|
|
|
@ -10,36 +10,36 @@ class S3StorageConfig(BaseSettings):
|
|||
"""
|
||||
|
||||
S3_ENDPOINT: Optional[str] = Field(
|
||||
description='S3 storage endpoint',
|
||||
description="S3 storage endpoint",
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_REGION: Optional[str] = Field(
|
||||
description='S3 storage region',
|
||||
description="S3 storage region",
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_BUCKET_NAME: Optional[str] = Field(
|
||||
description='S3 storage bucket name',
|
||||
description="S3 storage bucket name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_ACCESS_KEY: Optional[str] = Field(
|
||||
description='S3 storage access key',
|
||||
description="S3 storage access key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_SECRET_KEY: Optional[str] = Field(
|
||||
description='S3 storage secret key',
|
||||
description="S3 storage secret key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_ADDRESS_STYLE: str = Field(
|
||||
description='S3 storage address style',
|
||||
default='auto',
|
||||
description="S3 storage address style",
|
||||
default="auto",
|
||||
)
|
||||
|
||||
S3_USE_AWS_MANAGED_IAM: bool = Field(
|
||||
description='whether to use aws managed IAM for S3',
|
||||
description="whether to use aws managed IAM for S3",
|
||||
default=False,
|
||||
)
|
||||
|
|
|
@ -10,21 +10,21 @@ class AzureBlobStorageConfig(BaseSettings):
|
|||
"""
|
||||
|
||||
AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field(
|
||||
description='Azure Blob account name',
|
||||
description="Azure Blob account name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field(
|
||||
description='Azure Blob account key',
|
||||
description="Azure Blob account key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field(
|
||||
description='Azure Blob container name',
|
||||
description="Azure Blob container name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field(
|
||||
description='Azure Blob account URL',
|
||||
description="Azure Blob account URL",
|
||||
default=None,
|
||||
)
|
||||
|
|
|
@ -10,11 +10,11 @@ class GoogleCloudStorageConfig(BaseSettings):
|
|||
"""
|
||||
|
||||
GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field(
|
||||
description='Google Cloud storage bucket name',
|
||||
description="Google Cloud storage bucket name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field(
|
||||
description='Google Cloud storage service account json base64',
|
||||
description="Google Cloud storage service account json base64",
|
||||
default=None,
|
||||
)
|
||||
|
|
|
@ -10,27 +10,26 @@ class OCIStorageConfig(BaseSettings):
|
|||
"""
|
||||
|
||||
OCI_ENDPOINT: Optional[str] = Field(
|
||||
description='OCI storage endpoint',
|
||||
description="OCI storage endpoint",
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_REGION: Optional[str] = Field(
|
||||
description='OCI storage region',
|
||||
description="OCI storage region",
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_BUCKET_NAME: Optional[str] = Field(
|
||||
description='OCI storage bucket name',
|
||||
description="OCI storage bucket name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_ACCESS_KEY: Optional[str] = Field(
|
||||
description='OCI storage access key',
|
||||
description="OCI storage access key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_SECRET_KEY: Optional[str] = Field(
|
||||
description='OCI storage secret key',
|
||||
description="OCI storage secret key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
|
|
@ -10,26 +10,26 @@ class TencentCloudCOSStorageConfig(BaseSettings):
|
|||
"""
|
||||
|
||||
TENCENT_COS_BUCKET_NAME: Optional[str] = Field(
|
||||
description='Tencent Cloud COS bucket name',
|
||||
description="Tencent Cloud COS bucket name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_COS_REGION: Optional[str] = Field(
|
||||
description='Tencent Cloud COS region',
|
||||
description="Tencent Cloud COS region",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_COS_SECRET_ID: Optional[str] = Field(
|
||||
description='Tencent Cloud COS secret id',
|
||||
description="Tencent Cloud COS secret id",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_COS_SECRET_KEY: Optional[str] = Field(
|
||||
description='Tencent Cloud COS secret key',
|
||||
description="Tencent Cloud COS secret key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_COS_SCHEME: Optional[str] = Field(
|
||||
description='Tencent Cloud COS scheme',
|
||||
description="Tencent Cloud COS scheme",
|
||||
default=None,
|
||||
)
|
||||
|
|
|
@@ -10,35 +10,28 @@ class AnalyticdbConfig(BaseModel):
     https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
     """

-    ANALYTICDB_KEY_ID : Optional[str] = Field(
-        default=None,
-        description="The Access Key ID provided by Alibaba Cloud for authentication."
+    ANALYTICDB_KEY_ID: Optional[str] = Field(
+        default=None, description="The Access Key ID provided by Alibaba Cloud for authentication."
     )
-    ANALYTICDB_KEY_SECRET : Optional[str] = Field(
-        default=None,
-        description="The Secret Access Key corresponding to the Access Key ID for secure access."
+    ANALYTICDB_KEY_SECRET: Optional[str] = Field(
+        default=None, description="The Secret Access Key corresponding to the Access Key ID for secure access."
     )
-    ANALYTICDB_REGION_ID : Optional[str] = Field(
-        default=None,
-        description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
+    ANALYTICDB_REGION_ID: Optional[str] = Field(
+        default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
     )
-    ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
+    ANALYTICDB_INSTANCE_ID: Optional[str] = Field(
         default=None,
-        description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').."
+        description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')..",
     )
-    ANALYTICDB_ACCOUNT : Optional[str] = Field(
-        default=None,
-        description="The account name used to log in to the AnalyticDB instance."
+    ANALYTICDB_ACCOUNT: Optional[str] = Field(
+        default=None, description="The account name used to log in to the AnalyticDB instance."
     )
-    ANALYTICDB_PASSWORD : Optional[str] = Field(
-        default=None,
-        description="The password associated with the AnalyticDB account for authentication."
+    ANALYTICDB_PASSWORD: Optional[str] = Field(
+        default=None, description="The password associated with the AnalyticDB account for authentication."
     )
-    ANALYTICDB_NAMESPACE : Optional[str] = Field(
-        default=None,
-        description="The namespace within AnalyticDB for schema isolation."
+    ANALYTICDB_NAMESPACE: Optional[str] = Field(
+        default=None, description="The namespace within AnalyticDB for schema isolation."
     )
-    ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
-        default=None,
-        description="The password for accessing the specified namespace within the AnalyticDB instance."
+    ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field(
+        default=None, description="The password for accessing the specified namespace within the AnalyticDB instance."
     )
@@ -10,31 +10,31 @@ class ChromaConfig(BaseSettings):
     """

     CHROMA_HOST: Optional[str] = Field(
-        description='Chroma host',
+        description="Chroma host",
         default=None,
     )

     CHROMA_PORT: PositiveInt = Field(
-        description='Chroma port',
+        description="Chroma port",
         default=8000,
     )

     CHROMA_TENANT: Optional[str] = Field(
-        description='Chroma database',
+        description="Chroma database",
         default=None,
     )

     CHROMA_DATABASE: Optional[str] = Field(
-        description='Chroma database',
+        description="Chroma database",
         default=None,
     )

     CHROMA_AUTH_PROVIDER: Optional[str] = Field(
-        description='Chroma authentication provider',
+        description="Chroma authentication provider",
         default=None,
     )

     CHROMA_AUTH_CREDENTIALS: Optional[str] = Field(
-        description='Chroma authentication credentials',
+        description="Chroma authentication credentials",
         default=None,
     )
api/configs/middleware/vdb/elasticsearch_config.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+from typing import Optional
+
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings
+
+
+class ElasticsearchConfig(BaseSettings):
+    """
+    Elasticsearch configs
+    """
+
+    ELASTICSEARCH_HOST: Optional[str] = Field(
+        description="Elasticsearch host",
+        default="127.0.0.1",
+    )
+
+    ELASTICSEARCH_PORT: PositiveInt = Field(
+        description="Elasticsearch port",
+        default=9200,
+    )
+
+    ELASTICSEARCH_USERNAME: Optional[str] = Field(
+        description="Elasticsearch username",
+        default="elastic",
+    )
+
+    ELASTICSEARCH_PASSWORD: Optional[str] = Field(
+        description="Elasticsearch password",
+        default="elastic",
+    )
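For context, pydantic-settings populates a BaseSettings subclass from same-named environment variables and coerces the values to the declared types. A minimal usage sketch for the new ElasticsearchConfig (the import path is inferred from the file location when running from api/, not necessarily how Dify wires it):

    import os

    from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig

    # Override one value via the environment; the rest fall back to defaults.
    os.environ["ELASTICSEARCH_PORT"] = "9201"

    config = ElasticsearchConfig()
    print(config.ELASTICSEARCH_HOST)  # "127.0.0.1" unless overridden
    print(config.ELASTICSEARCH_PORT)  # 9201, validated as a PositiveInt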
@@ -10,31 +10,31 @@ class MilvusConfig(BaseSettings):
     """

     MILVUS_HOST: Optional[str] = Field(
-        description='Milvus host',
+        description="Milvus host",
         default=None,
     )

     MILVUS_PORT: PositiveInt = Field(
-        description='Milvus RestFul API port',
+        description="Milvus RestFul API port",
         default=9091,
     )

     MILVUS_USER: Optional[str] = Field(
-        description='Milvus user',
+        description="Milvus user",
         default=None,
     )

     MILVUS_PASSWORD: Optional[str] = Field(
-        description='Milvus password',
+        description="Milvus password",
         default=None,
     )

     MILVUS_SECURE: bool = Field(
-        description='whether to use SSL connection for Milvus',
+        description="whether to use SSL connection for Milvus",
         default=False,
     )

     MILVUS_DATABASE: str = Field(
-        description='Milvus database, default to `default`',
-        default='default',
+        description="Milvus database, default to `default`",
+        default="default",
     )
@@ -1,4 +1,3 @@
-
 from pydantic import BaseModel, Field, PositiveInt


@@ -8,31 +7,31 @@ class MyScaleConfig(BaseModel):
     """

     MYSCALE_HOST: str = Field(
-        description='MyScale host',
-        default='localhost',
+        description="MyScale host",
+        default="localhost",
     )

     MYSCALE_PORT: PositiveInt = Field(
-        description='MyScale port',
+        description="MyScale port",
         default=8123,
     )

     MYSCALE_USER: str = Field(
-        description='MyScale user',
-        default='default',
+        description="MyScale user",
+        default="default",
     )

     MYSCALE_PASSWORD: str = Field(
-        description='MyScale password',
-        default='',
+        description="MyScale password",
+        default="",
     )

     MYSCALE_DATABASE: str = Field(
-        description='MyScale database name',
-        default='default',
+        description="MyScale database name",
+        default="default",
     )

     MYSCALE_FTS_PARAMS: str = Field(
-        description='MyScale fts index parameters',
-        default='',
+        description="MyScale fts index parameters",
+        default="",
     )
@@ -10,26 +10,26 @@ class OpenSearchConfig(BaseSettings):
     """

     OPENSEARCH_HOST: Optional[str] = Field(
-        description='OpenSearch host',
+        description="OpenSearch host",
         default=None,
     )

     OPENSEARCH_PORT: PositiveInt = Field(
-        description='OpenSearch port',
+        description="OpenSearch port",
         default=9200,
     )

     OPENSEARCH_USER: Optional[str] = Field(
-        description='OpenSearch user',
+        description="OpenSearch user",
         default=None,
     )

     OPENSEARCH_PASSWORD: Optional[str] = Field(
-        description='OpenSearch password',
+        description="OpenSearch password",
         default=None,
     )

     OPENSEARCH_SECURE: bool = Field(
-        description='whether to use SSL connection for OpenSearch',
+        description="whether to use SSL connection for OpenSearch",
         default=False,
     )
@@ -10,26 +10,26 @@ class OracleConfig(BaseSettings):
     """

     ORACLE_HOST: Optional[str] = Field(
-        description='ORACLE host',
+        description="ORACLE host",
         default=None,
     )

     ORACLE_PORT: Optional[PositiveInt] = Field(
-        description='ORACLE port',
+        description="ORACLE port",
         default=1521,
     )

     ORACLE_USER: Optional[str] = Field(
-        description='ORACLE user',
+        description="ORACLE user",
         default=None,
     )

     ORACLE_PASSWORD: Optional[str] = Field(
-        description='ORACLE password',
+        description="ORACLE password",
         default=None,
     )

     ORACLE_DATABASE: Optional[str] = Field(
-        description='ORACLE database',
+        description="ORACLE database",
         default=None,
     )
@@ -10,26 +10,26 @@ class PGVectorConfig(BaseSettings):
     """

     PGVECTOR_HOST: Optional[str] = Field(
-        description='PGVector host',
+        description="PGVector host",
         default=None,
     )

     PGVECTOR_PORT: Optional[PositiveInt] = Field(
-        description='PGVector port',
+        description="PGVector port",
         default=5433,
     )

     PGVECTOR_USER: Optional[str] = Field(
-        description='PGVector user',
+        description="PGVector user",
         default=None,
     )

     PGVECTOR_PASSWORD: Optional[str] = Field(
-        description='PGVector password',
+        description="PGVector password",
         default=None,
     )

     PGVECTOR_DATABASE: Optional[str] = Field(
-        description='PGVector database',
+        description="PGVector database",
         default=None,
     )
@@ -10,26 +10,26 @@ class PGVectoRSConfig(BaseSettings):
     """

     PGVECTO_RS_HOST: Optional[str] = Field(
-        description='PGVectoRS host',
+        description="PGVectoRS host",
         default=None,
     )

     PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
-        description='PGVectoRS port',
+        description="PGVectoRS port",
         default=5431,
     )

     PGVECTO_RS_USER: Optional[str] = Field(
-        description='PGVectoRS user',
+        description="PGVectoRS user",
         default=None,
     )

     PGVECTO_RS_PASSWORD: Optional[str] = Field(
-        description='PGVectoRS password',
+        description="PGVectoRS password",
         default=None,
     )

     PGVECTO_RS_DATABASE: Optional[str] = Field(
-        description='PGVectoRS database',
+        description="PGVectoRS database",
         default=None,
     )
@@ -10,26 +10,26 @@ class QdrantConfig(BaseSettings):
     """

     QDRANT_URL: Optional[str] = Field(
-        description='Qdrant url',
+        description="Qdrant url",
         default=None,
     )

     QDRANT_API_KEY: Optional[str] = Field(
-        description='Qdrant api key',
+        description="Qdrant api key",
         default=None,
     )

     QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
-        description='Qdrant client timeout in seconds',
+        description="Qdrant client timeout in seconds",
         default=20,
     )

     QDRANT_GRPC_ENABLED: bool = Field(
-        description='whether enable grpc support for Qdrant connection',
+        description="whether enable grpc support for Qdrant connection",
         default=False,
     )

     QDRANT_GRPC_PORT: PositiveInt = Field(
-        description='Qdrant grpc port',
+        description="Qdrant grpc port",
         default=6334,
     )
@@ -10,26 +10,26 @@ class RelytConfig(BaseSettings):
     """

     RELYT_HOST: Optional[str] = Field(
-        description='Relyt host',
+        description="Relyt host",
         default=None,
     )

     RELYT_PORT: PositiveInt = Field(
-        description='Relyt port',
+        description="Relyt port",
         default=9200,
     )

     RELYT_USER: Optional[str] = Field(
-        description='Relyt user',
+        description="Relyt user",
         default=None,
     )

     RELYT_PASSWORD: Optional[str] = Field(
-        description='Relyt password',
+        description="Relyt password",
         default=None,
     )

     RELYT_DATABASE: Optional[str] = Field(
-        description='Relyt database',
-        default='default',
+        description="Relyt database",
+        default="default",
     )
@@ -10,41 +10,41 @@ class TencentVectorDBConfig(BaseSettings):
     """

     TENCENT_VECTOR_DB_URL: Optional[str] = Field(
-        description='Tencent Vector URL',
+        description="Tencent Vector URL",
         default=None,
     )

     TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field(
-        description='Tencent Vector API key',
+        description="Tencent Vector API key",
         default=None,
     )

     TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field(
-        description='Tencent Vector timeout in seconds',
+        description="Tencent Vector timeout in seconds",
         default=30,
     )

     TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
-        description='Tencent Vector username',
+        description="Tencent Vector username",
         default=None,
     )

     TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field(
-        description='Tencent Vector password',
+        description="Tencent Vector password",
         default=None,
     )

     TENCENT_VECTOR_DB_SHARD: PositiveInt = Field(
-        description='Tencent Vector sharding number',
+        description="Tencent Vector sharding number",
         default=1,
     )

     TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
-        description='Tencent Vector replicas',
+        description="Tencent Vector replicas",
         default=2,
     )

     TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
-        description='Tencent Vector Database',
+        description="Tencent Vector Database",
         default=None,
     )
@@ -10,26 +10,26 @@ class TiDBVectorConfig(BaseSettings):
     """

     TIDB_VECTOR_HOST: Optional[str] = Field(
-        description='TiDB Vector host',
+        description="TiDB Vector host",
         default=None,
     )

     TIDB_VECTOR_PORT: Optional[PositiveInt] = Field(
-        description='TiDB Vector port',
+        description="TiDB Vector port",
         default=4000,
     )

     TIDB_VECTOR_USER: Optional[str] = Field(
-        description='TiDB Vector user',
+        description="TiDB Vector user",
         default=None,
     )

     TIDB_VECTOR_PASSWORD: Optional[str] = Field(
-        description='TiDB Vector password',
+        description="TiDB Vector password",
         default=None,
     )

     TIDB_VECTOR_DATABASE: Optional[str] = Field(
-        description='TiDB Vector database',
+        description="TiDB Vector database",
         default=None,
     )
@@ -10,21 +10,21 @@ class WeaviateConfig(BaseSettings):
     """

     WEAVIATE_ENDPOINT: Optional[str] = Field(
-        description='Weaviate endpoint URL',
+        description="Weaviate endpoint URL",
         default=None,
     )

     WEAVIATE_API_KEY: Optional[str] = Field(
-        description='Weaviate API key',
+        description="Weaviate API key",
         default=None,
     )

     WEAVIATE_GRPC_ENABLED: bool = Field(
-        description='whether to enable gRPC for Weaviate connection',
+        description="whether to enable gRPC for Weaviate connection",
         default=True,
     )

     WEAVIATE_BATCH_SIZE: PositiveInt = Field(
-        description='Weaviate batch size',
+        description="Weaviate batch size",
         default=100,
     )
@@ -8,11 +8,11 @@ class PackagingInfo(BaseSettings):
     """

     CURRENT_VERSION: str = Field(
-        description='Dify version',
-        default='0.7.0',
+        description="Dify version",
+        default="0.7.3",
     )

     COMMIT_SHA: str = Field(
         description="SHA-1 checksum of the git commit used to build the app",
-        default='',
+        default="",
     )
@@ -1 +1 @@
-HIDDEN_VALUE = '[__HIDDEN__]'
+HIDDEN_VALUE = "[__HIDDEN__]"
@@ -1,22 +1,22 @@
 language_timezone_mapping = {
-    'en-US': 'America/New_York',
-    'zh-Hans': 'Asia/Shanghai',
-    'zh-Hant': 'Asia/Taipei',
-    'pt-BR': 'America/Sao_Paulo',
-    'es-ES': 'Europe/Madrid',
-    'fr-FR': 'Europe/Paris',
-    'de-DE': 'Europe/Berlin',
-    'ja-JP': 'Asia/Tokyo',
-    'ko-KR': 'Asia/Seoul',
-    'ru-RU': 'Europe/Moscow',
-    'it-IT': 'Europe/Rome',
-    'uk-UA': 'Europe/Kyiv',
-    'vi-VN': 'Asia/Ho_Chi_Minh',
-    'ro-RO': 'Europe/Bucharest',
-    'pl-PL': 'Europe/Warsaw',
-    'hi-IN': 'Asia/Kolkata',
-    'tr-TR': 'Europe/Istanbul',
-    'fa-IR': 'Asia/Tehran',
+    "en-US": "America/New_York",
+    "zh-Hans": "Asia/Shanghai",
+    "zh-Hant": "Asia/Taipei",
+    "pt-BR": "America/Sao_Paulo",
+    "es-ES": "Europe/Madrid",
+    "fr-FR": "Europe/Paris",
+    "de-DE": "Europe/Berlin",
+    "ja-JP": "Asia/Tokyo",
+    "ko-KR": "Asia/Seoul",
+    "ru-RU": "Europe/Moscow",
+    "it-IT": "Europe/Rome",
+    "uk-UA": "Europe/Kyiv",
+    "vi-VN": "Asia/Ho_Chi_Minh",
+    "ro-RO": "Europe/Bucharest",
+    "pl-PL": "Europe/Warsaw",
+    "hi-IN": "Asia/Kolkata",
+    "tr-TR": "Europe/Istanbul",
+    "fa-IR": "Asia/Tehran",
 }

 languages = list(language_timezone_mapping.keys())

@@ -26,6 +26,5 @@ def supported_language(lang):
     if lang in languages:
         return lang

-    error = ('{lang} is not a valid language.'
-             .format(lang=lang))
+    error = "{lang} is not a valid language.".format(lang=lang)
     raise ValueError(error)
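supported_language acts as a validator: it returns a known language code unchanged and raises ValueError otherwise. A short usage sketch against the function above:

    # Valid codes pass through untouched.
    assert supported_language("en-US") == "en-US"

    # Unknown codes raise with the formatted message.
    try:
        supported_language("xx-XX")
    except ValueError as exc:
        print(exc)  # xx-XX is not a valid language.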
@@ -5,82 +5,79 @@ from models.model import AppMode
 default_app_templates = {
     # workflow default mode
     AppMode.WORKFLOW: {
-        'app': {
-            'mode': AppMode.WORKFLOW.value,
-            'enable_site': True,
-            'enable_api': True
+        "app": {
+            "mode": AppMode.WORKFLOW.value,
+            "enable_site": True,
+            "enable_api": True,
         }
     },

     # completion default mode
     AppMode.COMPLETION: {
-        'app': {
-            'mode': AppMode.COMPLETION.value,
-            'enable_site': True,
-            'enable_api': True
+        "app": {
+            "mode": AppMode.COMPLETION.value,
+            "enable_site": True,
+            "enable_api": True,
         },
-        'model_config': {
-            'model': {
+        "model_config": {
+            "model": {
                 "provider": "openai",
                 "name": "gpt-4o",
                 "mode": "chat",
-                "completion_params": {}
+                "completion_params": {},
             },
-            'user_input_form': json.dumps([
-                {
-                    "paragraph": {
-                        "label": "Query",
-                        "variable": "query",
-                        "required": True,
-                        "default": ""
-                    }
-                }
-            ]),
-            'pre_prompt': '{{query}}'
+            "user_input_form": json.dumps(
+                [
+                    {
+                        "paragraph": {
+                            "label": "Query",
+                            "variable": "query",
+                            "required": True,
+                            "default": "",
+                        },
+                    },
+                ]
+            ),
+            "pre_prompt": "{{query}}",
         },
-
     },

     # chat default mode
     AppMode.CHAT: {
-        'app': {
-            'mode': AppMode.CHAT.value,
-            'enable_site': True,
-            'enable_api': True
+        "app": {
+            "mode": AppMode.CHAT.value,
+            "enable_site": True,
+            "enable_api": True,
         },
-        'model_config': {
-            'model': {
+        "model_config": {
+            "model": {
                 "provider": "openai",
                 "name": "gpt-4o",
                 "mode": "chat",
-                "completion_params": {}
-            }
-        }
+                "completion_params": {},
+            },
+        },
     },

     # advanced-chat default mode
     AppMode.ADVANCED_CHAT: {
-        'app': {
-            'mode': AppMode.ADVANCED_CHAT.value,
-            'enable_site': True,
-            'enable_api': True
-        }
+        "app": {
+            "mode": AppMode.ADVANCED_CHAT.value,
+            "enable_site": True,
+            "enable_api": True,
+        },
     },

     # agent-chat default mode
     AppMode.AGENT_CHAT: {
-        'app': {
-            'mode': AppMode.AGENT_CHAT.value,
-            'enable_site': True,
-            'enable_api': True
+        "app": {
+            "mode": AppMode.AGENT_CHAT.value,
+            "enable_site": True,
+            "enable_api": True,
         },
-        'model_config': {
-            'model': {
+        "model_config": {
+            "model": {
                 "provider": "openai",
                 "name": "gpt-4o",
                 "mode": "chat",
-                "completion_params": {}
-            }
-        }
-    }
+                "completion_params": {},
+            },
+        },
+    },
 }
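The template dictionary is keyed by AppMode, and user_input_form is stored as a JSON string, so consumers decode it before use. A hedged access sketch (assuming the imports this module already uses):

    import json

    from models.model import AppMode

    # Look up the completion-mode defaults defined above.
    template = default_app_templates[AppMode.COMPLETION]
    form = json.loads(template["model_config"]["user_input_form"])
    print(form[0]["paragraph"]["variable"])  # query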
@@ -1,3 +1,7 @@
 from contextvars import ContextVar

-tenant_id: ContextVar[str] = ContextVar('tenant_id')
+from core.workflow.entities.variable_pool import VariablePool
+
+tenant_id: ContextVar[str] = ContextVar("tenant_id")
+
+workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")
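ContextVar gives each request or task its own value without module-level mutable globals; the new workflow_variable_pool follows the same pattern as tenant_id. A standalone sketch of the set/get/reset semantics:

    from contextvars import ContextVar

    tenant_id: ContextVar[str] = ContextVar("tenant_id")

    token = tenant_id.set("tenant-123")
    assert tenant_id.get() == "tenant-123"
    tenant_id.reset(token)  # back to the unset state; .get() would now raise LookupError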
@@ -1,3 +1 @@
-
-
@@ -2,7 +2,7 @@ from flask import Blueprint

 from libs.external_api import ExternalApi

-bp = Blueprint('console', __name__, url_prefix='/console/api')
+bp = Blueprint("console", __name__, url_prefix="/console/api")
 api = ExternalApi(bp)

 # Import other controllers
@@ -15,24 +15,24 @@ from models.model import App, InstalledApp, RecommendedApp
 def admin_required(view):
     @wraps(view)
     def decorated(*args, **kwargs):
-        if not os.getenv('ADMIN_API_KEY'):
-            raise Unauthorized('API key is invalid.')
+        if not os.getenv("ADMIN_API_KEY"):
+            raise Unauthorized("API key is invalid.")

-        auth_header = request.headers.get('Authorization')
+        auth_header = request.headers.get("Authorization")
         if auth_header is None:
-            raise Unauthorized('Authorization header is missing.')
+            raise Unauthorized("Authorization header is missing.")

-        if ' ' not in auth_header:
-            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+        if " " not in auth_header:
+            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")

         auth_scheme, auth_token = auth_header.split(None, 1)
         auth_scheme = auth_scheme.lower()

-        if auth_scheme != 'bearer':
-            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+        if auth_scheme != "bearer":
+            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")

-        if os.getenv('ADMIN_API_KEY') != auth_token:
-            raise Unauthorized('API key is invalid.')
+        if os.getenv("ADMIN_API_KEY") != auth_token:
+            raise Unauthorized("API key is invalid.")

         return view(*args, **kwargs)
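The decorator requires an Authorization: Bearer <key> header whose token equals the ADMIN_API_KEY environment variable. A hedged client-side sketch (host, port, and payload values are placeholders, not taken from this diff; the path combines the console blueprint's /console/api prefix with the route registered below):

    import requests

    resp = requests.post(
        "http://localhost:5001/console/api/admin/insert-explore-apps",
        headers={"Authorization": "Bearer <ADMIN_API_KEY value>"},
        json={"app_id": "<uuid>", "language": "en-US", "category": "Demo", "position": 1},
    )
    print(resp.status_code)  # 201 on first insert, 200 on update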
@@ -44,37 +44,41 @@ class InsertExploreAppListApi(Resource):
     @admin_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('app_id', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('desc', type=str, location='json')
-        parser.add_argument('copyright', type=str, location='json')
-        parser.add_argument('privacy_policy', type=str, location='json')
-        parser.add_argument('custom_disclaimer', type=str, location='json')
-        parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json')
-        parser.add_argument('category', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('position', type=int, required=True, nullable=False, location='json')
+        parser.add_argument("app_id", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("desc", type=str, location="json")
+        parser.add_argument("copyright", type=str, location="json")
+        parser.add_argument("privacy_policy", type=str, location="json")
+        parser.add_argument("custom_disclaimer", type=str, location="json")
+        parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
+        parser.add_argument("category", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("position", type=int, required=True, nullable=False, location="json")
         args = parser.parse_args()

-        app = App.query.filter(App.id == args['app_id']).first()
+        app = App.query.filter(App.id == args["app_id"]).first()
         if not app:
             raise NotFound(f'App \'{args["app_id"]}\' is not found')

         site = app.site
         if not site:
-            desc = args['desc'] if args['desc'] else ''
-            copy_right = args['copyright'] if args['copyright'] else ''
-            privacy_policy = args['privacy_policy'] if args['privacy_policy'] else ''
-            custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else ''
+            desc = args["desc"] if args["desc"] else ""
+            copy_right = args["copyright"] if args["copyright"] else ""
+            privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
+            custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
         else:
-            desc = site.description if site.description else \
-                args['desc'] if args['desc'] else ''
-            copy_right = site.copyright if site.copyright else \
-                args['copyright'] if args['copyright'] else ''
-            privacy_policy = site.privacy_policy if site.privacy_policy else \
-                args['privacy_policy'] if args['privacy_policy'] else ''
-            custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \
-                args['custom_disclaimer'] if args['custom_disclaimer'] else ''
+            desc = site.description if site.description else args["desc"] if args["desc"] else ""
+            copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
+            privacy_policy = (
+                site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
+            )
+            custom_disclaimer = (
+                site.custom_disclaimer
+                if site.custom_disclaimer
+                else args["custom_disclaimer"]
+                if args["custom_disclaimer"]
+                else ""
+            )

-        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
+        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()

         if not recommended_app:
             recommended_app = RecommendedApp(
@@ -83,9 +87,9 @@ class InsertExploreAppListApi(Resource):
                 copyright=copy_right,
                 privacy_policy=privacy_policy,
                 custom_disclaimer=custom_disclaimer,
-                language=args['language'],
-                category=args['category'],
-                position=args['position']
+                language=args["language"],
+                category=args["category"],
+                position=args["position"],
             )

             db.session.add(recommended_app)
@@ -93,21 +97,21 @@ class InsertExploreAppListApi(Resource):
             app.is_public = True
             db.session.commit()

-            return {'result': 'success'}, 201
+            return {"result": "success"}, 201
         else:
             recommended_app.description = desc
             recommended_app.copyright = copy_right
             recommended_app.privacy_policy = privacy_policy
             recommended_app.custom_disclaimer = custom_disclaimer
-            recommended_app.language = args['language']
-            recommended_app.category = args['category']
-            recommended_app.position = args['position']
+            recommended_app.language = args["language"]
+            recommended_app.category = args["category"]
+            recommended_app.position = args["position"]

             app.is_public = True

             db.session.commit()

-            return {'result': 'success'}, 200
+            return {"result": "success"}, 200


 class InsertExploreAppApi(Resource):
@@ -116,15 +120,14 @@ class InsertExploreAppApi(Resource):
     def delete(self, app_id):
         recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
         if not recommended_app:
-            return {'result': 'success'}, 204
+            return {"result": "success"}, 204

         app = App.query.filter(App.id == recommended_app.app_id).first()
         if app:
             app.is_public = False

         installed_apps = InstalledApp.query.filter(
-            InstalledApp.app_id == recommended_app.app_id,
-            InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
+            InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
         ).all()

         for installed_app in installed_apps:
@@ -133,8 +136,8 @@ class InsertExploreAppApi(Resource):
         db.session.delete(recommended_app)
         db.session.commit()

-        return {'result': 'success'}, 204
+        return {"result": "success"}, 204


-api.add_resource(InsertExploreAppListApi, '/admin/insert-explore-apps')
-api.add_resource(InsertExploreAppApi, '/admin/insert-explore-apps/<uuid:app_id>')
+api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps")
+api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/<uuid:app_id>")
@@ -14,26 +14,21 @@ from .setup import setup_required
 from .wraps import account_initialization_required

 api_key_fields = {
-    'id': fields.String,
-    'type': fields.String,
-    'token': fields.String,
-    'last_used_at': TimestampField,
-    'created_at': TimestampField
+    "id": fields.String,
+    "type": fields.String,
+    "token": fields.String,
+    "last_used_at": TimestampField,
+    "created_at": TimestampField,
 }

-api_key_list = {
-    'data': fields.List(fields.Nested(api_key_fields), attribute="items")
-}
+api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}


 def _get_resource(resource_id, tenant_id, resource_model):
-    resource = resource_model.query.filter_by(
-        id=resource_id, tenant_id=tenant_id
-    ).first()
+    resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()

     if resource is None:
-        flask_restful.abort(
-            404, message=f"{resource_model.__name__} not found.")
+        flask_restful.abort(404, message=f"{resource_model.__name__} not found.")

     return resource
@@ -50,30 +45,32 @@ class BaseApiKeyListResource(Resource):
     @marshal_with(api_key_list)
     def get(self, resource_id):
         resource_id = str(resource_id)
-        _get_resource(resource_id, current_user.current_tenant_id,
-                      self.resource_model)
-        keys = db.session.query(ApiToken). \
-            filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \
-            all()
+        _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
+        keys = (
+            db.session.query(ApiToken)
+            .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
+            .all()
+        )
         return {"items": keys}

     @marshal_with(api_key_fields)
     def post(self, resource_id):
         resource_id = str(resource_id)
-        _get_resource(resource_id, current_user.current_tenant_id,
-                      self.resource_model)
+        _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
         if not current_user.is_admin_or_owner:
             raise Forbidden()

-        current_key_count = db.session.query(ApiToken). \
-            filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \
-            count()
+        current_key_count = (
+            db.session.query(ApiToken)
+            .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
+            .count()
+        )

         if current_key_count >= self.max_keys:
             flask_restful.abort(
                 400,
                 message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
-                code='max_keys_exceeded'
+                code="max_keys_exceeded",
             )

         key = ApiToken.generate_api_key(self.token_prefix, 24)
@@ -97,79 +94,78 @@ class BaseApiKeyResource(Resource):
     def delete(self, resource_id, api_key_id):
         resource_id = str(resource_id)
         api_key_id = str(api_key_id)
-        _get_resource(resource_id, current_user.current_tenant_id,
-                      self.resource_model)
+        _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)

         # The role of the current user in the ta table must be admin or owner
         if not current_user.is_admin_or_owner:
             raise Forbidden()

-        key = db.session.query(ApiToken). \
-            filter(getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id). \
-            first()
+        key = (
+            db.session.query(ApiToken)
+            .filter(
+                getattr(ApiToken, self.resource_id_field) == resource_id,
+                ApiToken.type == self.resource_type,
+                ApiToken.id == api_key_id,
+            )
+            .first()
+        )

         if key is None:
-            flask_restful.abort(404, message='API key not found')
+            flask_restful.abort(404, message="API key not found")

         db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
         db.session.commit()

-        return {'result': 'success'}, 204
+        return {"result": "success"}, 204


 class AppApiKeyListResource(BaseApiKeyListResource):
-
     def after_request(self, resp):
-        resp.headers['Access-Control-Allow-Origin'] = '*'
-        resp.headers['Access-Control-Allow-Credentials'] = 'true'
+        resp.headers["Access-Control-Allow-Origin"] = "*"
+        resp.headers["Access-Control-Allow-Credentials"] = "true"
         return resp

-    resource_type = 'app'
+    resource_type = "app"
     resource_model = App
-    resource_id_field = 'app_id'
-    token_prefix = 'app-'
+    resource_id_field = "app_id"
+    token_prefix = "app-"


 class AppApiKeyResource(BaseApiKeyResource):
-
     def after_request(self, resp):
-        resp.headers['Access-Control-Allow-Origin'] = '*'
-        resp.headers['Access-Control-Allow-Credentials'] = 'true'
+        resp.headers["Access-Control-Allow-Origin"] = "*"
+        resp.headers["Access-Control-Allow-Credentials"] = "true"
         return resp

-    resource_type = 'app'
+    resource_type = "app"
     resource_model = App
-    resource_id_field = 'app_id'
+    resource_id_field = "app_id"


 class DatasetApiKeyListResource(BaseApiKeyListResource):
-
     def after_request(self, resp):
-        resp.headers['Access-Control-Allow-Origin'] = '*'
-        resp.headers['Access-Control-Allow-Credentials'] = 'true'
+        resp.headers["Access-Control-Allow-Origin"] = "*"
+        resp.headers["Access-Control-Allow-Credentials"] = "true"
         return resp

-    resource_type = 'dataset'
+    resource_type = "dataset"
     resource_model = Dataset
-    resource_id_field = 'dataset_id'
-    token_prefix = 'ds-'
+    resource_id_field = "dataset_id"
+    token_prefix = "ds-"


 class DatasetApiKeyResource(BaseApiKeyResource):
-
     def after_request(self, resp):
-        resp.headers['Access-Control-Allow-Origin'] = '*'
-        resp.headers['Access-Control-Allow-Credentials'] = 'true'
+        resp.headers["Access-Control-Allow-Origin"] = "*"
+        resp.headers["Access-Control-Allow-Credentials"] = "true"
         return resp
-    resource_type = 'dataset'
+
+    resource_type = "dataset"
     resource_model = Dataset
-    resource_id_field = 'dataset_id'
+    resource_id_field = "dataset_id"


-api.add_resource(AppApiKeyListResource, '/apps/<uuid:resource_id>/api-keys')
-api.add_resource(AppApiKeyResource,
-                 '/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>')
-api.add_resource(DatasetApiKeyListResource,
-                 '/datasets/<uuid:resource_id>/api-keys')
-api.add_resource(DatasetApiKeyResource,
-                 '/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>')
+api.add_resource(AppApiKeyListResource, "/apps/<uuid:resource_id>/api-keys")
+api.add_resource(AppApiKeyResource, "/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
+api.add_resource(DatasetApiKeyListResource, "/datasets/<uuid:resource_id>/api-keys")
+api.add_resource(DatasetApiKeyResource, "/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
@@ -8,19 +8,18 @@ from services.advanced_prompt_template_service import AdvancedPromptTemplateServ


 class AdvancedPromptTemplateList(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     def get(self):
-
         parser = reqparse.RequestParser()
-        parser.add_argument('app_mode', type=str, required=True, location='args')
-        parser.add_argument('model_mode', type=str, required=True, location='args')
-        parser.add_argument('has_context', type=str, required=False, default='true', location='args')
-        parser.add_argument('model_name', type=str, required=True, location='args')
+        parser.add_argument("app_mode", type=str, required=True, location="args")
+        parser.add_argument("model_mode", type=str, required=True, location="args")
+        parser.add_argument("has_context", type=str, required=False, default="true", location="args")
+        parser.add_argument("model_name", type=str, required=True, location="args")
         args = parser.parse_args()

         return AdvancedPromptTemplateService.get_prompt(args)

-api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
+
+api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates")
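The parser reads its four parameters from the query string (location="args"). A hedged request sketch against the route registered above, assuming the console blueprint's /console/api prefix and a placeholder host and port:

    import requests

    resp = requests.get(
        "http://localhost:5001/console/api/app/prompt-templates",
        params={
            "app_mode": "chat",          # placeholder values; the accepted
            "model_mode": "completion",  # choices are service-defined
            "has_context": "true",
            "model_name": "gpt-4o",
        },
    )
    print(resp.json())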
@@ -18,15 +18,12 @@ class AgentLogApi(Resource):
     def get(self, app_model):
         """Get agent logs"""
         parser = reqparse.RequestParser()
-        parser.add_argument('message_id', type=uuid_value, required=True, location='args')
-        parser.add_argument('conversation_id', type=uuid_value, required=True, location='args')
+        parser.add_argument("message_id", type=uuid_value, required=True, location="args")
+        parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")

         args = parser.parse_args()

-        return AgentService.get_agent_logs(
-            app_model,
-            args['conversation_id'],
-            args['message_id']
-        )
+        return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])

-api.add_resource(AgentLogApi, '/apps/<uuid:app_id>/agent/logs')
+
+api.add_resource(AgentLogApi, "/apps/<uuid:app_id>/agent/logs")
@@ -21,23 +21,23 @@ class AnnotationReplyActionApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('annotation')
+    @cloud_edition_billing_resource_check("annotation")
     def post(self, app_id, action):
         if not current_user.is_editor:
             raise Forbidden()

         app_id = str(app_id)
         parser = reqparse.RequestParser()
-        parser.add_argument('score_threshold', required=True, type=float, location='json')
-        parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
-        parser.add_argument('embedding_model_name', required=True, type=str, location='json')
+        parser.add_argument("score_threshold", required=True, type=float, location="json")
+        parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
+        parser.add_argument("embedding_model_name", required=True, type=str, location="json")
         args = parser.parse_args()
-        if action == 'enable':
+        if action == "enable":
             result = AppAnnotationService.enable_app_annotation(args, app_id)
-        elif action == 'disable':
+        elif action == "disable":
             result = AppAnnotationService.disable_app_annotation(app_id)
         else:
-            raise ValueError('Unsupported annotation reply action')
+            raise ValueError("Unsupported annotation reply action")
         return result, 200
@@ -66,7 +66,7 @@ class AppAnnotationSettingUpdateApi(Resource):
         annotation_setting_id = str(annotation_setting_id)

         parser = reqparse.RequestParser()
-        parser.add_argument('score_threshold', required=True, type=float, location='json')
+        parser.add_argument("score_threshold", required=True, type=float, location="json")
         args = parser.parse_args()

         result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
@@ -77,28 +77,24 @@ class AnnotationReplyActionStatusApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('annotation')
+    @cloud_edition_billing_resource_check("annotation")
     def get(self, app_id, job_id, action):
         if not current_user.is_editor:
             raise Forbidden()

         job_id = str(job_id)
-        app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id))
+        app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
         cache_result = redis_client.get(app_annotation_job_key)
         if cache_result is None:
             raise ValueError("The job is not exist.")

         job_status = cache_result.decode()
-        error_msg = ''
-        if job_status == 'error':
-            app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
+        error_msg = ""
+        if job_status == "error":
+            app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
             error_msg = redis_client.get(app_annotation_error_key).decode()

-        return {
-            'job_id': job_id,
-            'job_status': job_status,
-            'error_msg': error_msg
-        }, 200
+        return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200


 class AnnotationListApi(Resource):
@@ -109,18 +105,18 @@ class AnnotationListApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()

-        page = request.args.get('page', default=1, type=int)
-        limit = request.args.get('limit', default=20, type=int)
-        keyword = request.args.get('keyword', default=None, type=str)
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
+        keyword = request.args.get("keyword", default=None, type=str)

         app_id = str(app_id)
         annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
         response = {
-            'data': marshal(annotation_list, annotation_fields),
-            'has_more': len(annotation_list) == limit,
-            'limit': limit,
-            'total': total,
-            'page': page
+            "data": marshal(annotation_list, annotation_fields),
+            "has_more": len(annotation_list) == limit,
+            "limit": limit,
+            "total": total,
+            "page": page,
         }
         return response, 200
@@ -135,9 +131,7 @@ class AnnotationExportApi(Resource):
         app_id = str(app_id)
         annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
-        response = {
-            'data': marshal(annotation_list, annotation_fields)
-        }
+        response = {"data": marshal(annotation_list, annotation_fields)}
         return response, 200
@@ -145,7 +139,7 @@ class AnnotationCreateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('annotation')
+    @cloud_edition_billing_resource_check("annotation")
     @marshal_with(annotation_fields)
     def post(self, app_id):
         if not current_user.is_editor:
@@ -153,8 +147,8 @@ class AnnotationCreateApi(Resource):
         app_id = str(app_id)
         parser = reqparse.RequestParser()
-        parser.add_argument('question', required=True, type=str, location='json')
-        parser.add_argument('answer', required=True, type=str, location='json')
+        parser.add_argument("question", required=True, type=str, location="json")
+        parser.add_argument("answer", required=True, type=str, location="json")
         args = parser.parse_args()
         annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
         return annotation
@@ -164,7 +158,7 @@ class AnnotationUpdateDeleteApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('annotation')
+    @cloud_edition_billing_resource_check("annotation")
     @marshal_with(annotation_fields)
     def post(self, app_id, annotation_id):
         if not current_user.is_editor:
@@ -173,8 +167,8 @@ class AnnotationUpdateDeleteApi(Resource):
         app_id = str(app_id)
         annotation_id = str(annotation_id)
         parser = reqparse.RequestParser()
-        parser.add_argument('question', required=True, type=str, location='json')
-        parser.add_argument('answer', required=True, type=str, location='json')
+        parser.add_argument("question", required=True, type=str, location="json")
+        parser.add_argument("answer", required=True, type=str, location="json")
         args = parser.parse_args()
         annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
         return annotation
@@ -189,29 +183,29 @@ class AnnotationUpdateDeleteApi(Resource):
         app_id = str(app_id)
         annotation_id = str(annotation_id)
         AppAnnotationService.delete_app_annotation(app_id, annotation_id)
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200


 class AnnotationBatchImportApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('annotation')
+    @cloud_edition_billing_resource_check("annotation")
     def post(self, app_id):
         if not current_user.is_editor:
             raise Forbidden()

         app_id = str(app_id)
         # get file from request
-        file = request.files['file']
+        file = request.files["file"]
         # check file
-        if 'file' not in request.files:
+        if "file" not in request.files:
             raise NoFileUploadedError()

         if len(request.files) > 1:
             raise TooManyFilesError()
         # check file type
-        if not file.filename.endswith('.csv'):
+        if not file.filename.endswith(".csv"):
             raise ValueError("Invalid file type. Only CSV files are allowed")
         return AppAnnotationService.batch_import_app_annotations(app_id, file)
@@ -220,27 +214,23 @@ class AnnotationBatchImportStatusApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('annotation')
+    @cloud_edition_billing_resource_check("annotation")
     def get(self, app_id, job_id):
         if not current_user.is_editor:
             raise Forbidden()

         job_id = str(job_id)
-        indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
+        indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is None:
             raise ValueError("The job is not exist.")
         job_status = cache_result.decode()
-        error_msg = ''
-        if job_status == 'error':
-            indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
+        error_msg = ""
+        if job_status == "error":
+            indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id))
             error_msg = redis_client.get(indexing_error_msg_key).decode()

-        return {
-            'job_id': job_id,
-            'job_status': job_status,
-            'error_msg': error_msg
-        }, 200
+        return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200


 class AnnotationHitHistoryListApi(Resource):
@@ -251,30 +241,32 @@ class AnnotationHitHistoryListApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()

-        page = request.args.get('page', default=1, type=int)
-        limit = request.args.get('limit', default=20, type=int)
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
         app_id = str(app_id)
         annotation_id = str(annotation_id)
-        annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
-                                                                                               page, limit)
+        annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
+            app_id, annotation_id, page, limit
+        )
         response = {
-            'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
-            'has_more': len(annotation_hit_history_list) == limit,
-            'limit': limit,
-            'total': total,
-            'page': page
+            "data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
+            "has_more": len(annotation_hit_history_list) == limit,
+            "limit": limit,
+            "total": total,
+            "page": page,
         }
         return response


-api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
-api.add_resource(AnnotationReplyActionStatusApi,
-                 '/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
-api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
-api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
-api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
-api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
-api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
-api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
-api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
-api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')
+api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>")
+api.add_resource(
+    AnnotationReplyActionStatusApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>"
+)
+api.add_resource(AnnotationListApi, "/apps/<uuid:app_id>/annotations")
+api.add_resource(AnnotationExportApi, "/apps/<uuid:app_id>/annotations/export")
+api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
+api.add_resource(AnnotationBatchImportApi, "/apps/<uuid:app_id>/annotations/batch-import")
+api.add_resource(AnnotationBatchImportStatusApi, "/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
+api.add_resource(AnnotationHitHistoryListApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
+api.add_resource(AppAnnotationSettingDetailApi, "/apps/<uuid:app_id>/annotation-setting")
+api.add_resource(AppAnnotationSettingUpdateApi, "/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")
@@ -18,27 +18,35 @@ from libs.login import login_required
 from services.app_dsl_service import AppDslService
 from services.app_service import AppService

-ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
+ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]


 class AppListApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     def get(self):
         """Get app list"""
+
         def uuid_list(value):
             try:
-                return [str(uuid.UUID(v)) for v in value.split(',')]
+                return [str(uuid.UUID(v)) for v in value.split(",")]
             except ValueError:
                 abort(400, message="Invalid UUID format in tag_ids.")

         parser = reqparse.RequestParser()
-        parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
-        parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False)
-        parser.add_argument('name', type=str, location='args', required=False)
-        parser.add_argument('tag_ids', type=uuid_list, location='args', required=False)
+        parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
+        parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
+        parser.add_argument(
+            "mode",
+            type=str,
+            choices=["chat", "workflow", "agent-chat", "channel", "all"],
+            default="all",
+            location="args",
+            required=False,
+        )
+        parser.add_argument("name", type=str, location="args", required=False)
+        parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)

         args = parser.parse_args()
@@ -46,7 +54,7 @@ class AppListApi(Resource):
         app_service = AppService()
         app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
         if not app_pagination:
-            return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False}
+            return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}

         return marshal(app_pagination, app_pagination_fields)
@@ -54,22 +62,23 @@ class AppListApi(Resource):
     @login_required
     @account_initialization_required
     @marshal_with(app_detail_fields)
-    @cloud_edition_billing_resource_check('apps')
+    @cloud_edition_billing_resource_check("apps")
     def post(self):
         """Create app"""
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
-        parser.add_argument('description', type=str, location='json')
-        parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json')
-        parser.add_argument('icon', type=str, location='json')
-        parser.add_argument('icon_background', type=str, location='json')
+        parser.add_argument("name", type=str, required=True, location="json")
+        parser.add_argument("description", type=str, location="json")
+        parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
+        parser.add_argument("icon_type", type=str, location="json")
+        parser.add_argument("icon", type=str, location="json")
+        parser.add_argument("icon_background", type=str, location="json")
         args = parser.parse_args()

         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()

-        if 'mode' not in args or args['mode'] is None:
+        if "mode" not in args or args["mode"] is None:
             raise BadRequest("mode is required")

         app_service = AppService()
@@ -83,7 +92,7 @@ class AppImportApi(Resource):
     @login_required
     @account_initialization_required
     @marshal_with(app_detail_fields_with_site)
-    @cloud_edition_billing_resource_check('apps')
+    @cloud_edition_billing_resource_check("apps")
     def post(self):
         """Import app"""
         # The role of the current user in the ta table must be admin, owner, or editor
@@ -91,18 +100,16 @@ class AppImportApi(Resource):
             raise Forbidden()

         parser = reqparse.RequestParser()
-        parser.add_argument('data', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('name', type=str, location='json')
-        parser.add_argument('description', type=str, location='json')
-        parser.add_argument('icon', type=str, location='json')
-        parser.add_argument('icon_background', type=str, location='json')
+        parser.add_argument("data", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("name", type=str, location="json")
+        parser.add_argument("description", type=str, location="json")
+        parser.add_argument("icon_type", type=str, location="json")
+        parser.add_argument("icon", type=str, location="json")
+        parser.add_argument("icon_background", type=str, location="json")
         args = parser.parse_args()

         app = AppDslService.import_and_create_new_app(
-            tenant_id=current_user.current_tenant_id,
-            data=args['data'],
-            args=args,
-            account=current_user
+            tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user
         )

         return app, 201
@@ -113,7 +120,7 @@ class AppImportFromUrlApi(Resource):
     @login_required
     @account_initialization_required
     @marshal_with(app_detail_fields_with_site)
-    @cloud_edition_billing_resource_check('apps')
+    @cloud_edition_billing_resource_check("apps")
     def post(self):
         """Import app from url"""
         # The role of the current user in the ta table must be admin, owner, or editor
@@ -121,25 +128,21 @@ class AppImportFromUrlApi(Resource):
             raise Forbidden()

         parser = reqparse.RequestParser()
-        parser.add_argument('url', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('name', type=str, location='json')
-        parser.add_argument('description', type=str, location='json')
-        parser.add_argument('icon', type=str, location='json')
-        parser.add_argument('icon_background', type=str, location='json')
+        parser.add_argument("url", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("name", type=str, location="json")
+        parser.add_argument("description", type=str, location="json")
+        parser.add_argument("icon", type=str, location="json")
+        parser.add_argument("icon_background", type=str, location="json")
         args = parser.parse_args()

         app = AppDslService.import_and_create_new_app_from_url(
-            tenant_id=current_user.current_tenant_id,
-            url=args['url'],
-            args=args,
-            account=current_user
+            tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user
         )

         return app, 201


 class AppApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -163,13 +166,15 @@ class AppApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()

         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('description', type=str, location='json')
-        parser.add_argument('icon', type=str, location='json')
-        parser.add_argument('icon_background', type=str, location='json')
-        parser.add_argument('max_active_requests', type=int, location='json')
+        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("max_active_requests", type=int, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
|
@ -190,7 +195,7 @@ class AppApi(Resource):
|
|||
app_service = AppService()
|
||||
app_service.delete_app(app_model)
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class AppCopyApi(Resource):
|
||||
|
@ -206,18 +211,16 @@ class AppCopyApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, location='json')
|
||||
parser.add_argument('description', type=str, location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
|
||||
app = AppDslService.import_and_create_new_app(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data=data,
|
||||
args=args,
|
||||
account=current_user
|
||||
tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user
|
||||
)
|
||||
|
||||
return app, 201
|
||||
|
@ -236,12 +239,10 @@ class AppExportApi(Resource):
|
|||
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args')
|
||||
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
return {
|
||||
"data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret'])
|
||||
}
|
||||
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
|
||||
|
||||
|
||||
class AppNameApi(Resource):
|
||||
|
@ -254,13 +255,13 @@ class AppNameApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_name(app_model, args.get('name'))
|
||||
app_model = app_service.update_app_name(app_model, args.get("name"))
|
||||
|
||||
return app_model
|
||||
|
||||
|
@ -275,14 +276,14 @@ class AppIconApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background'))
|
||||
app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background"))
|
||||
|
||||
return app_model
|
||||
|
||||
|
@ -297,13 +298,13 @@ class AppSiteStatus(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('enable_site', type=bool, required=True, location='json')
|
||||
parser.add_argument("enable_site", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_site_status(app_model, args.get('enable_site'))
|
||||
app_model = app_service.update_app_site_status(app_model, args.get("enable_site"))
|
||||
|
||||
return app_model
|
||||
|
||||
|
@ -318,13 +319,13 @@ class AppApiStatus(Resource):
|
|||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('enable_api', type=bool, required=True, location='json')
|
||||
parser.add_argument("enable_api", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_api_status(app_model, args.get('enable_api'))
|
||||
app_model = app_service.update_app_api_status(app_model, args.get("enable_api"))
|
||||
|
||||
return app_model
|
||||
|
||||
|
@ -335,9 +336,7 @@ class AppTraceApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
"""Get app trace"""
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(
|
||||
app_id=app_id
|
||||
)
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id)
|
||||
|
||||
return app_trace_config
|
||||
|
||||
|
@ -349,27 +348,27 @@ class AppTraceApi(Resource):
|
|||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('enabled', type=bool, required=True, location='json')
|
||||
parser.add_argument('tracing_provider', type=str, required=True, location='json')
|
||||
parser.add_argument("enabled", type=bool, required=True, location="json")
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
OpsTraceManager.update_app_tracing_config(
|
||||
app_id=app_id,
|
||||
enabled=args['enabled'],
|
||||
tracing_provider=args['tracing_provider'],
|
||||
enabled=args["enabled"],
|
||||
tracing_provider=args["tracing_provider"],
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(AppListApi, '/apps')
|
||||
api.add_resource(AppImportApi, '/apps/import')
|
||||
api.add_resource(AppImportFromUrlApi, '/apps/import/url')
|
||||
api.add_resource(AppApi, '/apps/<uuid:app_id>')
|
||||
api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy')
|
||||
api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')
|
||||
api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
|
||||
api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon')
|
||||
api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
|
||||
api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
|
||||
api.add_resource(AppTraceApi, '/apps/<uuid:app_id>/trace')
|
||||
api.add_resource(AppListApi, "/apps")
|
||||
api.add_resource(AppImportApi, "/apps/import")
|
||||
api.add_resource(AppImportFromUrlApi, "/apps/import/url")
|
||||
api.add_resource(AppApi, "/apps/<uuid:app_id>")
|
||||
api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
|
||||
api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")
|
||||
api.add_resource(AppNameApi, "/apps/<uuid:app_id>/name")
|
||||
api.add_resource(AppIconApi, "/apps/<uuid:app_id>/icon")
|
||||
api.add_resource(AppSiteStatus, "/apps/<uuid:app_id>/site-enable")
|
||||
api.add_resource(AppApiStatus, "/apps/<uuid:app_id>/api-enable")
|
||||
api.add_resource(AppTraceApi, "/apps/<uuid:app_id>/trace")
|
||||
|
|
|
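
The hunks above (likely the console app controller, given the /apps routes) are almost entirely the quote-style rewrite that the new `ruff format --check ./api` CI step enforces: single quotes become double quotes, multi-line calls gain trailing commas, and short keyword-argument lists collapse onto one line. A minimal runnable sketch of the reformatted reqparse pattern, using a hypothetical DemoApi resource that is not part of this diff:

from flask import Flask
from flask_restful import Api, Resource, reqparse

app = Flask(__name__)
api = Api(app)


class DemoApi(Resource):  # hypothetical resource, not from the commit
    def get(self):
        parser = reqparse.RequestParser()
        # ruff format style: double quotes, one argument per line,
        # trailing comma when the call spans multiple lines
        parser.add_argument(
            "mode",
            type=str,
            choices=["chat", "workflow", "all"],
            default="all",
            location="args",
            required=False,
        )
        args = parser.parse_args()
        return {"mode": args["mode"]}


api.add_resource(DemoApi, "/demo")
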
@@ -39,7 +39,7 @@ class ChatMessageAudioApi(Resource):
     @account_initialization_required
     @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
     def post(self, app_model):
-        file = request.files['file']
+        file = request.files["file"]

         try:
             response = AudioService.transcript_asr(

@@ -85,31 +85,31 @@ class ChatMessageTextApi(Resource):

         try:
             parser = reqparse.RequestParser()
-            parser.add_argument('message_id', type=str, location='json')
-            parser.add_argument('text', type=str, location='json')
-            parser.add_argument('voice', type=str, location='json')
-            parser.add_argument('streaming', type=bool, location='json')
+            parser.add_argument("message_id", type=str, location="json")
+            parser.add_argument("text", type=str, location="json")
+            parser.add_argument("voice", type=str, location="json")
+            parser.add_argument("streaming", type=bool, location="json")
             args = parser.parse_args()

-            message_id = args.get('message_id', None)
-            text = args.get('text', None)
-            if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
-                    and app_model.workflow
-                    and app_model.workflow.features_dict):
-                text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
-                voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+            message_id = args.get("message_id", None)
+            text = args.get("text", None)
+            if (
+                app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                and app_model.workflow
+                and app_model.workflow.features_dict
+            ):
+                text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
+                voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
             else:
                 try:
-                    voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
-                        'voice')
+                    voice = (
+                        args.get("voice")
+                        if args.get("voice")
+                        else app_model.app_model_config.text_to_speech_dict.get("voice")
+                    )
                 except Exception:
                     voice = None
-            response = AudioService.transcript_tts(
-                app_model=app_model,
-                text=text,
-                message_id=message_id,
-                voice=voice
-            )
+            response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
             return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")

@@ -145,12 +145,12 @@ class TextModesApi(Resource):
     def get(self, app_model):
         try:
             parser = reqparse.RequestParser()
-            parser.add_argument('language', type=str, required=True, location='args')
+            parser.add_argument("language", type=str, required=True, location="args")
             args = parser.parse_args()

             response = AudioService.transcript_tts_voices(
                 tenant_id=app_model.tenant_id,
-                language=args['language'],
+                language=args["language"],
             )

             return response

@@ -179,6 +179,6 @@ class TextModesApi(Resource):
             raise InternalServerError()


-api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
-api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
-api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices')
+api.add_resource(ChatMessageAudioApi, "/apps/<uuid:app_id>/audio-to-text")
+api.add_resource(ChatMessageTextApi, "/apps/<uuid:app_id>/text-to-audio")
+api.add_resource(TextModesApi, "/apps/<uuid:app_id>/text-to-audio/voices")
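
The largest change in this audio controller only re-wraps the conditional expression that picks a text-to-speech voice. A self-contained sketch of that fallback logic, with a plain dict standing in for the real app_model objects:

def pick_voice(requested, features):
    # same fallback as the diff: use the caller's voice when given,
    # otherwise fall back to the app's configured text-to-speech voice
    text_to_speech = features.get("text_to_speech") or {}
    return requested if requested else text_to_speech.get("voice")


print(pick_voice(None, {"text_to_speech": {"voice": "alloy"}}))  # -> alloy
print(pick_voice("echo", {"text_to_speech": {"voice": "alloy"}}))  # -> echo
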
@@ -17,6 +17,7 @@ from controllers.console.app.error import (
 from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import (

@@ -31,37 +32,33 @@ from libs.helper import uuid_value
 from libs.login import login_required
 from models.model import AppMode
 from services.app_generate_service import AppGenerateService
+from services.errors.llm import InvokeRateLimitError


 # define completion message api for user
 class CompletionMessageApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.COMPLETION)
     def post(self, app_model):
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, location='json', default='')
-        parser.add_argument('files', type=list, required=False, location='json')
-        parser.add_argument('model_config', type=dict, required=True, location='json')
-        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, location="json", default="")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("model_config", type=dict, required=True, location="json")
+        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
         args = parser.parse_args()

-        streaming = args['response_mode'] != 'blocking'
-        args['auto_generate_name'] = False
+        streaming = args["response_mode"] != "blocking"
+        args["auto_generate_name"] = False

         account = flask_login.current_user

         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=account,
-                args=args,
-                invoke_from=InvokeFrom.DEBUGGER,
-                streaming=streaming
+                app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
             )

             return helper.compact_generate_response(response)

@@ -97,7 +94,7 @@ class CompletionMessageStopApi(Resource):

         AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)

-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200


 class ChatMessageApi(Resource):

@@ -107,27 +104,23 @@ class ChatMessageApi(Resource):
     @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
     def post(self, app_model):
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, required=True, location='json')
-        parser.add_argument('files', type=list, required=False, location='json')
-        parser.add_argument('model_config', type=dict, required=True, location='json')
-        parser.add_argument('conversation_id', type=uuid_value, location='json')
-        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, required=True, location="json")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("model_config", type=dict, required=True, location="json")
+        parser.add_argument("conversation_id", type=uuid_value, location="json")
+        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
         args = parser.parse_args()

-        streaming = args['response_mode'] != 'blocking'
-        args['auto_generate_name'] = False
+        streaming = args["response_mode"] != "blocking"
+        args["auto_generate_name"] = False

         account = flask_login.current_user

         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=account,
-                args=args,
-                invoke_from=InvokeFrom.DEBUGGER,
-                streaming=streaming
+                app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
             )

             return helper.compact_generate_response(response)

@@ -144,6 +137,8 @@ class ChatMessageApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
+        except InvokeRateLimitError as ex:
+            raise InvokeRateLimitHttpError(ex.description)
         except InvokeError as e:
             raise CompletionRequestError(e.description)
         except (ValueError, AppInvokeQuotaExceededError) as e:

@@ -163,10 +158,10 @@ class ChatMessageStopApi(Resource):

         AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)

-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200


-api.add_resource(CompletionMessageApi, '/apps/<uuid:app_id>/completion-messages')
-api.add_resource(CompletionMessageStopApi, '/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop')
-api.add_resource(ChatMessageApi, '/apps/<uuid:app_id>/chat-messages')
-api.add_resource(ChatMessageStopApi, '/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop')
+api.add_resource(CompletionMessageApi, "/apps/<uuid:app_id>/completion-messages")
+api.add_resource(CompletionMessageStopApi, "/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
+api.add_resource(ChatMessageApi, "/apps/<uuid:app_id>/chat-messages")
+api.add_resource(ChatMessageStopApi, "/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
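
Beyond reformatting, this completion/chat-message controller gains a new except branch that maps the service-level InvokeRateLimitError onto an HTTP-layer 429 error. A minimal sketch of that translation pattern, using stand-in exception classes rather than the real dify ones:

class InvokeRateLimitError(Exception):  # stand-in for the service-layer error
    def __init__(self, description):
        self.description = description


class InvokeRateLimitHttpError(Exception):  # stand-in for the HTTP-layer error (429)
    code = 429

    def __init__(self, description):
        self.description = description


def generate_or_raise():
    try:
        # pretend the generation service hit a provider rate limit
        raise InvokeRateLimitError("Rate limit exceeded, retry later")
    except InvokeRateLimitError as ex:
        # same pattern as the new except branch: re-raise with the original
        # description so the API client sees a meaningful 429 message
        raise InvokeRateLimitHttpError(ex.description)
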
@@ -26,34 +26,32 @@ from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation


 class CompletionConversationApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.COMPLETION)
     @marshal_with(conversation_pagination_fields)
     def get(self, app_model):
-        if not current_user.is_admin_or_owner:
+        if not current_user.is_editor:
             raise Forbidden()
         parser = reqparse.RequestParser()
-        parser.add_argument('keyword', type=str, location='args')
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('annotation_status', type=str,
-                            choices=['annotated', 'not_annotated', 'all'], default='all', location='args')
-        parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
+        parser.add_argument("keyword", type=str, location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument(
+            "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
+        )
+        parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
         args = parser.parse_args()

-        query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion')
+        query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")

-        if args['keyword']:
-            query = query.join(
-                Message, Message.conversation_id == Conversation.id
-            ).filter(
+        if args["keyword"]:
+            query = query.join(Message, Message.conversation_id == Conversation.id).filter(
                 or_(
-                    Message.query.ilike('%{}%'.format(args['keyword'])),
-                    Message.answer.ilike('%{}%'.format(args['keyword']))
+                    Message.query.ilike("%{}%".format(args["keyword"])),
+                    Message.answer.ilike("%{}%".format(args["keyword"])),
                 )
             )

@@ -61,8 +59,8 @@ class CompletionConversationApi(Resource):
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc

-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)

             start_datetime_timezone = timezone.localize(start_datetime)

@@ -70,8 +68,8 @@ class CompletionConversationApi(Resource):

             query = query.where(Conversation.created_at >= start_datetime_utc)

-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=59)

             end_datetime_timezone = timezone.localize(end_datetime)

@@ -79,36 +77,32 @@ class CompletionConversationApi(Resource):

             query = query.where(Conversation.created_at < end_datetime_utc)

-        if args['annotation_status'] == "annotated":
+        if args["annotation_status"] == "annotated":
             query = query.options(joinedload(Conversation.message_annotations)).join(
                 MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
             )
-        elif args['annotation_status'] == "not_annotated":
-            query = query.outerjoin(
-                MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
-            ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
+        elif args["annotation_status"] == "not_annotated":
+            query = (
+                query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
+                .group_by(Conversation.id)
+                .having(func.count(MessageAnnotation.id) == 0)
+            )

         query = query.order_by(Conversation.created_at.desc())

-        conversations = db.paginate(
-            query,
-            page=args['page'],
-            per_page=args['limit'],
-            error_out=False
-        )
+        conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)

         return conversations


 class CompletionConversationDetailApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.COMPLETION)
     @marshal_with(conversation_message_detail_fields)
     def get(self, app_model, conversation_id):
-        if not current_user.is_admin_or_owner:
+        if not current_user.is_editor:
             raise Forbidden()
         conversation_id = str(conversation_id)

@@ -119,12 +113,15 @@ class CompletionConversationDetailApi(Resource):
     @account_initialization_required
     @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
     def delete(self, app_model, conversation_id):
-        if not current_user.is_admin_or_owner:
+        if not current_user.is_editor:
             raise Forbidden()
         conversation_id = str(conversation_id)

-        conversation = db.session.query(Conversation) \
-            .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
+        conversation = (
+            db.session.query(Conversation)
+            .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
+            .first()
+        )

         if not conversation:
             raise NotFound("Conversation Not Exists.")

@@ -132,11 +129,10 @@ class CompletionConversationDetailApi(Resource):
         conversation.is_deleted = True
         db.session.commit()

-        return {'result': 'success'}, 204
+        return {"result": "success"}, 204


 class ChatConversationApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required

@@ -146,20 +142,28 @@ class ChatConversationApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()
         parser = reqparse.RequestParser()
-        parser.add_argument('keyword', type=str, location='args')
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('annotation_status', type=str,
-                            choices=['annotated', 'not_annotated', 'all'], default='all', location='args')
-        parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args')
-        parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument("keyword", type=str, location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument(
+            "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
+        )
+        parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
+        parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
+        parser.add_argument(
+            "sort_by",
+            type=str,
+            choices=["created_at", "-created_at", "updated_at", "-updated_at"],
+            required=False,
+            default="-updated_at",
+            location="args",
+        )
         args = parser.parse_args()

         subquery = (
             db.session.query(
-                Conversation.id.label('conversation_id'),
-                EndUser.session_id.label('from_end_user_session_id')
+                Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
             )
             .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
             .subquery()

@@ -167,19 +171,19 @@ class ChatConversationApi(Resource):

         query = db.select(Conversation).where(Conversation.app_id == app_model.id)

-        if args['keyword']:
-            keyword_filter = '%{}%'.format(args['keyword'])
-            query = query.join(
-                Message, Message.conversation_id == Conversation.id,
-            ).join(
-                subquery, subquery.c.conversation_id == Conversation.id
-            ).filter(
+        if args["keyword"]:
+            keyword_filter = "%{}%".format(args["keyword"])
+            message_subquery = (
+                db.session.query(Message.conversation_id)
+                .filter(or_(Message.query.ilike(keyword_filter), Message.answer.ilike(keyword_filter)))
+                .subquery()
+            )
+            query = query.join(subquery, subquery.c.conversation_id == Conversation.id).filter(
                 or_(
-                    Message.query.ilike(keyword_filter),
-                    Message.answer.ilike(keyword_filter),
+                    Conversation.id.in_(message_subquery),
                     Conversation.name.ilike(keyword_filter),
                     Conversation.introduction.ilike(keyword_filter),
-                    subquery.c.from_end_user_session_id.ilike(keyword_filter)
+                    subquery.c.from_end_user_session_id.ilike(keyword_filter),
                 ),
             )

@@ -187,8 +191,8 @@ class ChatConversationApi(Resource):
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc

-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)

             start_datetime_timezone = timezone.localize(start_datetime)

@@ -196,8 +200,8 @@ class ChatConversationApi(Resource):

             query = query.where(Conversation.created_at >= start_datetime_utc)

-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=59)

             end_datetime_timezone = timezone.localize(end_datetime)

@@ -205,40 +209,46 @@ class ChatConversationApi(Resource):

             query = query.where(Conversation.created_at < end_datetime_utc)

-        if args['annotation_status'] == "annotated":
+        if args["annotation_status"] == "annotated":
             query = query.options(joinedload(Conversation.message_annotations)).join(
                 MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
             )
-        elif args['annotation_status'] == "not_annotated":
-            query = query.outerjoin(
-                MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
-            ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
+        elif args["annotation_status"] == "not_annotated":
+            query = (
+                query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
+                .group_by(Conversation.id)
+                .having(func.count(MessageAnnotation.id) == 0)
+            )

-        if args['message_count_gte'] and args['message_count_gte'] >= 1:
+        if args["message_count_gte"] and args["message_count_gte"] >= 1:
             query = (
                 query.options(joinedload(Conversation.messages))
                 .join(Message, Message.conversation_id == Conversation.id)
                 .group_by(Conversation.id)
-                .having(func.count(Message.id) >= args['message_count_gte'])
+                .having(func.count(Message.id) >= args["message_count_gte"])
             )

         if app_model.mode == AppMode.ADVANCED_CHAT.value:
             query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)

-        query = query.order_by(Conversation.created_at.desc())
+        match args["sort_by"]:
+            case "created_at":
+                query = query.order_by(Conversation.created_at.asc())
+            case "-created_at":
+                query = query.order_by(Conversation.created_at.desc())
+            case "updated_at":
+                query = query.order_by(Conversation.updated_at.asc())
+            case "-updated_at":
+                query = query.order_by(Conversation.updated_at.desc())
+            case _:
+                query = query.order_by(Conversation.created_at.desc())

-        conversations = db.paginate(
-            query,
-            page=args['page'],
-            per_page=args['limit'],
-            error_out=False
-        )
+        conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)

         return conversations


 class ChatConversationDetailApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required

@@ -256,12 +266,15 @@ class ChatConversationDetailApi(Resource):
     @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
     @account_initialization_required
     def delete(self, app_model, conversation_id):
-        if not current_user.is_admin_or_owner:
+        if not current_user.is_editor:
             raise Forbidden()
         conversation_id = str(conversation_id)

-        conversation = db.session.query(Conversation) \
-            .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
+        conversation = (
+            db.session.query(Conversation)
+            .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
+            .first()
+        )

         if not conversation:
             raise NotFound("Conversation Not Exists.")

@@ -269,18 +282,21 @@ class ChatConversationDetailApi(Resource):
         conversation.is_deleted = True
         db.session.commit()

-        return {'result': 'success'}, 204
+        return {"result": "success"}, 204


-api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
-api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
-api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')
-api.add_resource(ChatConversationDetailApi, '/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>')
+api.add_resource(CompletionConversationApi, "/apps/<uuid:app_id>/completion-conversations")
+api.add_resource(CompletionConversationDetailApi, "/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
+api.add_resource(ChatConversationApi, "/apps/<uuid:app_id>/chat-conversations")
+api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")


 def _get_conversation(app_model, conversation_id):
-    conversation = db.session.query(Conversation) \
-        .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
+    conversation = (
+        db.session.query(Conversation)
+        .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
+        .first()
+    )

     if not conversation:
         raise NotFound("Conversation Not Exists.")
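
The one behavioral addition in this conversation controller is the new sort_by query argument and the match statement that maps it onto an ORDER BY clause, where a leading "-" means descending. A plain-Python sketch of the same dispatch, without SQLAlchemy (match statements need Python 3.10+):

def order_key(sort_by):
    # mirrors the match block in ChatConversationApi.get
    match sort_by:
        case "created_at":
            return ("created_at", "asc")
        case "-created_at":
            return ("created_at", "desc")
        case "updated_at":
            return ("updated_at", "asc")
        case "-updated_at":
            return ("updated_at", "desc")
        case _:
            # unknown values fall back to newest-first, like the diff
            return ("created_at", "desc")


assert order_key("-updated_at") == ("updated_at", "desc")
assert order_key("bogus") == ("created_at", "desc")
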
@@ -21,7 +21,7 @@ class ConversationVariablesApi(Resource):
     @marshal_with(paginated_conversation_variable_fields)
     def get(self, app_model):
         parser = reqparse.RequestParser()
-        parser.add_argument('conversation_id', type=str, location='args')
+        parser.add_argument("conversation_id", type=str, location="args")
         args = parser.parse_args()

         stmt = (

@@ -29,10 +29,10 @@ class ConversationVariablesApi(Resource):
             .where(ConversationVariable.app_id == app_model.id)
             .order_by(ConversationVariable.created_at)
         )
-        if args['conversation_id']:
-            stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id'])
+        if args["conversation_id"]:
+            stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
         else:
-            raise ValueError('conversation_id is required')
+            raise ValueError("conversation_id is required")

         # NOTE: This is a temporary solution to avoid performance issues.
         page = 1

@@ -43,14 +43,14 @@ class ConversationVariablesApi(Resource):
             rows = session.scalars(stmt).all()

         return {
-            'page': page,
-            'limit': page_size,
-            'total': len(rows),
-            'has_more': False,
-            'data': [
+            "page": page,
+            "limit": page_size,
+            "total": len(rows),
+            "has_more": False,
+            "data": [
                 {
-                    'created_at': row.created_at,
-                    'updated_at': row.updated_at,
+                    "created_at": row.created_at,
+                    "updated_at": row.updated_at,
                     **row.to_variable().model_dump(),
                 }
                 for row in rows

@@ -58,4 +58,4 @@ class ConversationVariablesApi(Resource):
         }


-api.add_resource(ConversationVariablesApi, '/apps/<uuid:app_id>/conversation-variables')
+api.add_resource(ConversationVariablesApi, "/apps/<uuid:app_id>/conversation-variables")
@@ -2,116 +2,128 @@ from libs.exception import BaseHTTPException


 class AppNotFoundError(BaseHTTPException):
-    error_code = 'app_not_found'
+    error_code = "app_not_found"
     description = "App not found."
     code = 404


 class ProviderNotInitializeError(BaseHTTPException):
-    error_code = 'provider_not_initialize'
-    description = "No valid model provider credentials found. " \
-                  "Please go to Settings -> Model Provider to complete your provider credentials."
+    error_code = "provider_not_initialize"
+    description = (
+        "No valid model provider credentials found. "
+        "Please go to Settings -> Model Provider to complete your provider credentials."
+    )
     code = 400


 class ProviderQuotaExceededError(BaseHTTPException):
-    error_code = 'provider_quota_exceeded'
-    description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
-                  "Please go to Settings -> Model Provider to complete your own provider credentials."
+    error_code = "provider_quota_exceeded"
+    description = (
+        "Your quota for Dify Hosted Model Provider has been exhausted. "
+        "Please go to Settings -> Model Provider to complete your own provider credentials."
+    )
     code = 400


 class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
-    error_code = 'model_currently_not_support'
+    error_code = "model_currently_not_support"
     description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
     code = 400


 class ConversationCompletedError(BaseHTTPException):
-    error_code = 'conversation_completed'
+    error_code = "conversation_completed"
     description = "The conversation has ended. Please start a new conversation."
     code = 400


 class AppUnavailableError(BaseHTTPException):
-    error_code = 'app_unavailable'
+    error_code = "app_unavailable"
     description = "App unavailable, please check your app configurations."
     code = 400


 class CompletionRequestError(BaseHTTPException):
-    error_code = 'completion_request_error'
+    error_code = "completion_request_error"
     description = "Completion request failed."
     code = 400


 class AppMoreLikeThisDisabledError(BaseHTTPException):
-    error_code = 'app_more_like_this_disabled'
+    error_code = "app_more_like_this_disabled"
     description = "The 'More like this' feature is disabled. Please refresh your page."
     code = 403


 class NoAudioUploadedError(BaseHTTPException):
-    error_code = 'no_audio_uploaded'
+    error_code = "no_audio_uploaded"
     description = "Please upload your audio."
     code = 400


 class AudioTooLargeError(BaseHTTPException):
-    error_code = 'audio_too_large'
+    error_code = "audio_too_large"
     description = "Audio size exceeded. {message}"
     code = 413


 class UnsupportedAudioTypeError(BaseHTTPException):
-    error_code = 'unsupported_audio_type'
+    error_code = "unsupported_audio_type"
     description = "Audio type not allowed."
     code = 415


 class ProviderNotSupportSpeechToTextError(BaseHTTPException):
-    error_code = 'provider_not_support_speech_to_text'
+    error_code = "provider_not_support_speech_to_text"
     description = "Provider not support speech to text."
     code = 400


 class NoFileUploadedError(BaseHTTPException):
-    error_code = 'no_file_uploaded'
+    error_code = "no_file_uploaded"
     description = "Please upload your file."
     code = 400


 class TooManyFilesError(BaseHTTPException):
-    error_code = 'too_many_files'
+    error_code = "too_many_files"
     description = "Only one file is allowed."
     code = 400


 class DraftWorkflowNotExist(BaseHTTPException):
-    error_code = 'draft_workflow_not_exist'
+    error_code = "draft_workflow_not_exist"
     description = "Draft workflow need to be initialized."
     code = 400


 class DraftWorkflowNotSync(BaseHTTPException):
-    error_code = 'draft_workflow_not_sync'
+    error_code = "draft_workflow_not_sync"
     description = "Workflow graph might have been modified, please refresh and resubmit."
     code = 400


 class TracingConfigNotExist(BaseHTTPException):
-    error_code = 'trace_config_not_exist'
+    error_code = "trace_config_not_exist"
     description = "Trace config not exist."
     code = 400


 class TracingConfigIsExist(BaseHTTPException):
-    error_code = 'trace_config_is_exist'
+    error_code = "trace_config_is_exist"
     description = "Trace config is exist."
     code = 400


 class TracingConfigCheckError(BaseHTTPException):
-    error_code = 'trace_config_check_error'
+    error_code = "trace_config_check_error"
     description = "Invalid Credentials."
     code = 400
+
+
+class InvokeRateLimitError(BaseHTTPException):
+    """Raised when the Invoke returns rate limit error."""
+
+    error_code = "rate_limit_error"
+    description = "Rate Limit Error"
+    code = 429
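
Every class in this error module changes only its quote style, except the new InvokeRateLimitError added at the end. One formatting detail worth calling out: ruff format rewrites the backslash-continued descriptions into parenthesized implicit string concatenation. A small sketch of that idiom on its own, showing the two forms produce identical strings:

# backslash continuation (the old style in the diff)
description_old = "No valid model provider credentials found. " \
                  "Please go to Settings -> Model Provider to complete your provider credentials."

# parenthesized implicit concatenation (what ruff format produces)
description_new = (
    "No valid model provider credentials found. "
    "Please go to Settings -> Model Provider to complete your provider credentials."
)

assert description_old == description_new
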
@@ -24,21 +24,21 @@ class RuleGenerateApi(Resource):
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('instruction', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_config', type=dict, required=True, nullable=False, location='json')
-        parser.add_argument('no_variable', type=bool, required=True, default=False, location='json')
+        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
         args = parser.parse_args()

         account = current_user
-        PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512'))
+        PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))

         try:
             rules = LLMGenerator.generate_rule_config(
                 tenant_id=account.current_tenant_id,
-                instruction=args['instruction'],
-                model_config=args['model_config'],
-                no_variable=args['no_variable'],
-                rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS
+                instruction=args["instruction"],
+                model_config=args["model_config"],
+                no_variable=args["no_variable"],
+                rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)

@@ -52,4 +52,4 @@ class RuleGenerateApi(Resource):
         return rules


-api.add_resource(RuleGenerateApi, '/rule-generate')
+api.add_resource(RuleGenerateApi, "/rule-generate")
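
One line in this rule-generate hunk is easy to misread: the token cap is read from the environment as a string and cast. A minimal sketch of the same idiom, runnable on its own:

import os

# os.getenv returns a string (or the string default), so cast explicitly;
# with the variable unset this prints the 512 fallback from the diff
PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))
print(PROMPT_GENERATION_MAX_TOKENS)
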
@ -33,9 +33,9 @@ from services.message_service import MessageService
|
|||
|
||||
class ChatMessageListApi(Resource):
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(message_detail_fields))
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_detail_fields)),
|
||||
}
|
||||
|
||||
@setup_required
|
||||
|
@ -45,55 +45,69 @@ class ChatMessageListApi(Resource):
|
|||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
|
||||
parser.add_argument('first_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
parser.add_argument("first_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
conversation = db.session.query(Conversation).filter(
|
||||
Conversation.id == args['conversation_id'],
|
||||
Conversation.app_id == app_model.id
|
||||
).first()
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
if args['first_id']:
|
||||
first_message = db.session.query(Message) \
|
||||
.filter(Message.conversation_id == conversation.id, Message.id == args['first_id']).first()
|
||||
if args["first_id"]:
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
.filter(Message.conversation_id == conversation.id, Message.id == args["first_id"])
|
||||
.first()
|
||||
)
|
||||
|
||||
if not first_message:
|
||||
raise NotFound("First message not found")
|
||||
|
||||
history_messages = db.session.query(Message).filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < first_message.created_at,
|
||||
Message.id != first_message.id
|
||||
) \
|
||||
.order_by(Message.created_at.desc()).limit(args['limit']).all()
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < first_message.created_at,
|
||||
Message.id != first_message.id,
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args["limit"])
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
|
||||
.order_by(Message.created_at.desc()).limit(args['limit']).all()
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
.filter(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args["limit"])
|
||||
.all()
|
||||
)
|
||||
|
||||
has_more = False
|
||||
if len(history_messages) == args['limit']:
|
||||
if len(history_messages) == args["limit"]:
|
||||
current_page_first_message = history_messages[-1]
|
||||
rest_count = db.session.query(Message).filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id
|
||||
).count()
|
||||
rest_count = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
return InfiniteScrollPagination(
|
||||
data=history_messages,
|
||||
limit=args['limit'],
|
||||
has_more=has_more
|
||||
)
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
|
||||
|
||||
|
||||
class MessageFeedbackApi(Resource):
|
||||
|
@ -103,49 +117,46 @@ class MessageFeedbackApi(Resource):
|
|||
@get_app_model
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('message_id', required=True, type=uuid_value, location='json')
|
||||
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
|
||||
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = str(args['message_id'])
|
||||
message_id = str(args["message_id"])
|
||||
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id
|
||||
).first()
|
||||
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
feedback = message.admin_feedback
|
||||
|
||||
if not args['rating'] and feedback:
|
||||
if not args["rating"] and feedback:
|
||||
db.session.delete(feedback)
|
||||
elif args['rating'] and feedback:
|
||||
feedback.rating = args['rating']
|
||||
elif not args['rating'] and not feedback:
|
||||
raise ValueError('rating cannot be None when feedback not exists')
|
||||
elif args["rating"] and feedback:
|
||||
feedback.rating = args["rating"]
|
||||
elif not args["rating"] and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
else:
|
||||
feedback = MessageFeedback(
|
||||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=args['rating'],
|
||||
from_source='admin',
|
||||
from_account_id=current_user.id
|
||||
rating=args["rating"],
|
||||
from_source="admin",
|
||||
from_account_id=current_user.id,
|
||||
)
|
||||
         db.session.add(feedback)
         db.session.commit()
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
 class MessageAnnotationApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('annotation')
+    @cloud_edition_billing_resource_check("annotation")
     @get_app_model
     @marshal_with(annotation_fields)
     def post(self, app_model):
@@ -153,10 +164,10 @@ class MessageAnnotationApi(Resource):
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('message_id', required=False, type=uuid_value, location='json')
-        parser.add_argument('question', required=True, type=str, location='json')
-        parser.add_argument('answer', required=True, type=str, location='json')
-        parser.add_argument('annotation_reply', required=False, type=dict, location='json')
+        parser.add_argument("message_id", required=False, type=uuid_value, location="json")
+        parser.add_argument("question", required=True, type=str, location="json")
+        parser.add_argument("answer", required=True, type=str, location="json")
+        parser.add_argument("annotation_reply", required=False, type=dict, location="json")
         args = parser.parse_args()
         annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
 
@@ -169,11 +180,9 @@ class MessageAnnotationCountApi(Resource):
     @account_initialization_required
     @get_app_model
     def get(self, app_model):
-        count = db.session.query(MessageAnnotation).filter(
-            MessageAnnotation.app_id == app_model.id
-        ).count()
+        count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count()
 
-        return {'count': count}
+        return {"count": count}
 
 
 class MessageSuggestedQuestionApi(Resource):
@@ -186,10 +195,7 @@ class MessageSuggestedQuestionApi(Resource):
 
         try:
             questions = MessageService.get_suggested_questions_after_answer(
-                app_model=app_model,
-                message_id=message_id,
-                user=current_user,
-                invoke_from=InvokeFrom.DEBUGGER
+                app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER
             )
         except MessageNotExistsError:
             raise NotFound("Message not found")
@@ -209,7 +215,7 @@ class MessageSuggestedQuestionApi(Resource):
             logging.exception("internal server error.")
             raise InternalServerError()
 
-        return {'data': questions}
+        return {"data": questions}
 
 
 class MessageApi(Resource):
@@ -221,10 +227,7 @@ class MessageApi(Resource):
     def get(self, app_model, message_id):
         message_id = str(message_id)
 
-        message = db.session.query(Message).filter(
-            Message.id == message_id,
-            Message.app_id == app_model.id
-        ).first()
+        message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
 
         if not message:
             raise NotFound("Message Not Exists.")
@@ -232,9 +235,9 @@ class MessageApi(Resource):
         return message
 
 
-api.add_resource(MessageSuggestedQuestionApi, '/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions')
-api.add_resource(ChatMessageListApi, '/apps/<uuid:app_id>/chat-messages', endpoint='console_chat_messages')
-api.add_resource(MessageFeedbackApi, '/apps/<uuid:app_id>/feedbacks')
-api.add_resource(MessageAnnotationApi, '/apps/<uuid:app_id>/annotations')
-api.add_resource(MessageAnnotationCountApi, '/apps/<uuid:app_id>/annotations/count')
-api.add_resource(MessageApi, '/apps/<uuid:app_id>/messages/<uuid:message_id>', endpoint='console_message')
+api.add_resource(MessageSuggestedQuestionApi, "/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
+api.add_resource(ChatMessageListApi, "/apps/<uuid:app_id>/chat-messages", endpoint="console_chat_messages")
+api.add_resource(MessageFeedbackApi, "/apps/<uuid:app_id>/feedbacks")
+api.add_resource(MessageAnnotationApi, "/apps/<uuid:app_id>/annotations")
+api.add_resource(MessageAnnotationCountApi, "/apps/<uuid:app_id>/annotations/count")
+api.add_resource(MessageApi, "/apps/<uuid:app_id>/messages/<uuid:message_id>", endpoint="console_message")
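
The message endpoints above all funnel request parsing through flask_restful's RequestParser before touching the database. A minimal runnable sketch of that pattern, assuming a bare Flask app; the resource class and route here are illustrative, not part of the diff:

from flask import Flask
from flask_restful import Api, Resource, reqparse

app = Flask(__name__)
api = Api(app)


class ExampleAnnotationApi(Resource):  # hypothetical resource, for illustration only
    def post(self):
        parser = reqparse.RequestParser()
        # location="json" reads from the request body; required=True rejects
        # payloads that omit the field with a 400 response.
        parser.add_argument("question", required=True, type=str, location="json")
        parser.add_argument("answer", required=True, type=str, location="json")
        args = parser.parse_args()
        return {"result": "success", "question": args["question"]}


api.add_resource(ExampleAnnotationApi, "/example/annotations")
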
@@ -19,37 +19,35 @@ from services.app_model_config_service import AppModelConfigService
 
 
 class ModelConfigResource(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
     def post(self, app_model):
-
         """Modify app model config"""
         # validate config
         model_configuration = AppModelConfigService.validate_configuration(
-            tenant_id=current_user.current_tenant_id,
-            config=request.json,
-            app_mode=AppMode.value_of(app_model.mode)
+            tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode)
         )
 
         new_app_model_config = AppModelConfig(
             app_id=app_model.id,
             created_by=current_user.id,
             updated_by=current_user.id,
         )
         new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
 
         if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
             # get original app model config
-            original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
-                AppModelConfig.id == app_model.app_model_config_id
-            ).first()
+            original_app_model_config: AppModelConfig = (
+                db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
+            )
             agent_mode = original_app_model_config.agent_mode_dict
             # decrypt agent tool parameters if it's secret-input
             parameter_map = {}
             masked_parameter_map = {}
             tool_map = {}
-            for tool in agent_mode.get('tools') or []:
+            for tool in agent_mode.get("tools") or []:
                 if not isinstance(tool, dict) or len(tool.keys()) <= 3:
                     continue
 
@@ -66,7 +64,7 @@ class ModelConfigResource(Resource):
                         tool_runtime=tool_runtime,
                         provider_name=agent_tool_entity.provider_id,
                         provider_type=agent_tool_entity.provider_type,
-                        identity_id=f'AGENT.{app_model.id}'
+                        identity_id=f"AGENT.{app_model.id}",
                     )
                 except Exception as e:
                     continue
@@ -79,18 +77,18 @@ class ModelConfigResource(Resource):
                 parameters = {}
                 masked_parameter = {}
 
-                key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
+                key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
                 masked_parameter_map[key] = masked_parameter
                 parameter_map[key] = parameters
                 tool_map[key] = tool_runtime
 
             # encrypt agent tool parameters if it's secret-input
             agent_mode = new_app_model_config.agent_mode_dict
-            for tool in agent_mode.get('tools') or []:
+            for tool in agent_mode.get("tools") or []:
                 agent_tool_entity = AgentToolEntity(**tool)
 
                 # get tool
-                key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
+                key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
                 if key in tool_map:
                     tool_runtime = tool_map[key]
                 else:
@@ -108,7 +106,7 @@ class ModelConfigResource(Resource):
                     tool_runtime=tool_runtime,
                     provider_name=agent_tool_entity.provider_id,
                     provider_type=agent_tool_entity.provider_type,
-                    identity_id=f'AGENT.{app_model.id}'
+                    identity_id=f"AGENT.{app_model.id}",
                 )
                 manager.delete_tool_parameters_cache()
 
@@ -116,15 +114,17 @@ class ModelConfigResource(Resource):
                 if agent_tool_entity.tool_parameters:
                     if key not in masked_parameter_map:
                         continue
 
                     for masked_key, masked_value in masked_parameter_map[key].items():
-                        if masked_key in agent_tool_entity.tool_parameters and \
-                                agent_tool_entity.tool_parameters[masked_key] == masked_value:
+                        if (
+                            masked_key in agent_tool_entity.tool_parameters
+                            and agent_tool_entity.tool_parameters[masked_key] == masked_value
+                        ):
                             agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
 
                 # encrypt parameters
                 if agent_tool_entity.tool_parameters:
-                    tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
+                    tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
 
                 # update app model config
                 new_app_model_config.agent_mode = json.dumps(agent_mode)
 
@@ -135,12 +135,9 @@ class ModelConfigResource(Resource):
         app_model.app_model_config_id = new_app_model_config.id
         db.session.commit()
 
-        app_model_config_was_updated.send(
-            app_model,
-            app_model_config=new_app_model_config
-        )
+        app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
-api.add_resource(ModelConfigResource, '/apps/<uuid:app_id>/model-config')
+api.add_resource(ModelConfigResource, "/apps/<uuid:app_id>/model-config")
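
The agent-tool branch above decrypts stored tool parameters, compares what the client sent against the masked copies, and restores the original secret wherever the client echoed the mask back unchanged, before everything is re-encrypted. A minimal sketch of that compare-and-restore step on plain dicts; the function name is hypothetical, not a Dify API:

def restore_masked_parameters(submitted: dict, masked: dict, original: dict) -> dict:
    """If the client echoed a masked placeholder back unchanged, put the
    original secret back before the values are re-encrypted."""
    restored = dict(submitted)
    for key, masked_value in masked.items():
        if restored.get(key) == masked_value:
            restored[key] = original.get(key)
    return restored


# Example: the client left "api_key" masked but edited "endpoint".
print(restore_masked_parameters(
    {"api_key": "***abc", "endpoint": "https://new.example.com"},
    {"api_key": "***abc"},
    {"api_key": "sk-secret"},
))  # {'api_key': 'sk-secret', 'endpoint': 'https://new.example.com'}
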
@@ -18,13 +18,11 @@ class TraceAppConfigApi(Resource):
     @account_initialization_required
     def get(self, app_id):
         parser = reqparse.RequestParser()
-        parser.add_argument('tracing_provider', type=str, required=True, location='args')
+        parser.add_argument("tracing_provider", type=str, required=True, location="args")
         args = parser.parse_args()
 
         try:
-            trace_config = OpsService.get_tracing_app_config(
-                app_id=app_id, tracing_provider=args['tracing_provider']
-            )
+            trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
             if not trace_config:
                 return {"has_not_configured": True}
             return trace_config
@@ -37,19 +35,17 @@ class TraceAppConfigApi(Resource):
     def post(self, app_id):
         """Create a new trace app configuration"""
         parser = reqparse.RequestParser()
-        parser.add_argument('tracing_provider', type=str, required=True, location='json')
-        parser.add_argument('tracing_config', type=dict, required=True, location='json')
+        parser.add_argument("tracing_provider", type=str, required=True, location="json")
+        parser.add_argument("tracing_config", type=dict, required=True, location="json")
         args = parser.parse_args()
 
         try:
             result = OpsService.create_tracing_app_config(
-                app_id=app_id,
-                tracing_provider=args['tracing_provider'],
-                tracing_config=args['tracing_config']
+                app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
             )
             if not result:
                 raise TracingConfigIsExist()
-            if result.get('error'):
+            if result.get("error"):
                 raise TracingConfigCheckError()
             return result
         except Exception as e:
@@ -61,15 +57,13 @@ class TraceAppConfigApi(Resource):
     def patch(self, app_id):
         """Update an existing trace app configuration"""
         parser = reqparse.RequestParser()
-        parser.add_argument('tracing_provider', type=str, required=True, location='json')
-        parser.add_argument('tracing_config', type=dict, required=True, location='json')
+        parser.add_argument("tracing_provider", type=str, required=True, location="json")
+        parser.add_argument("tracing_config", type=dict, required=True, location="json")
         args = parser.parse_args()
 
         try:
             result = OpsService.update_tracing_app_config(
-                app_id=app_id,
-                tracing_provider=args['tracing_provider'],
-                tracing_config=args['tracing_config']
+                app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
             )
             if not result:
                 raise TracingConfigNotExist()
@@ -83,14 +77,11 @@ class TraceAppConfigApi(Resource):
     def delete(self, app_id):
         """Delete an existing trace app configuration"""
         parser = reqparse.RequestParser()
-        parser.add_argument('tracing_provider', type=str, required=True, location='args')
+        parser.add_argument("tracing_provider", type=str, required=True, location="args")
         args = parser.parse_args()
 
         try:
-            result = OpsService.delete_tracing_app_config(
-                app_id=app_id,
-                tracing_provider=args['tracing_provider']
-            )
+            result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
             if not result:
                 raise TracingConfigNotExist()
             return {"result": "success"}
@@ -98,4 +89,4 @@ class TraceAppConfigApi(Resource):
             raise e
 
 
-api.add_resource(TraceAppConfigApi, '/apps/<uuid:app_id>/trace-config')
+api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")
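
Note the location split in TraceAppConfigApi: GET and DELETE read tracing_provider from the query string (location="args"), while POST and PATCH read the provider and config from the JSON body (location="json"). A small sketch of that split, assuming a bare Flask app; the class and route are illustrative:

from flask import Flask
from flask_restful import Api, Resource, reqparse

app = Flask(__name__)
api = Api(app)


class TraceConfigExample(Resource):  # hypothetical resource mirroring the pattern above
    def get(self):
        # Read-style verbs take the provider from the query string.
        parser = reqparse.RequestParser()
        parser.add_argument("tracing_provider", type=str, required=True, location="args")
        return dict(parser.parse_args())

    def post(self):
        # Write-style verbs take provider and config from the JSON body.
        parser = reqparse.RequestParser()
        parser.add_argument("tracing_provider", type=str, required=True, location="json")
        parser.add_argument("tracing_config", type=dict, required=True, location="json")
        return dict(parser.parse_args())


api.add_resource(TraceConfigExample, "/example/trace-config")
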
@@ -1,3 +1,5 @@
+from datetime import datetime, timezone
+
 from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, NotFound
@@ -15,22 +17,24 @@ from models.model import Site
 
 def parse_app_site_args():
     parser = reqparse.RequestParser()
-    parser.add_argument('title', type=str, required=False, location='json')
-    parser.add_argument('icon', type=str, required=False, location='json')
-    parser.add_argument('icon_background', type=str, required=False, location='json')
-    parser.add_argument('description', type=str, required=False, location='json')
-    parser.add_argument('default_language', type=supported_language, required=False, location='json')
-    parser.add_argument('chat_color_theme', type=str, required=False, location='json')
-    parser.add_argument('chat_color_theme_inverted', type=bool, required=False, location='json')
-    parser.add_argument('customize_domain', type=str, required=False, location='json')
-    parser.add_argument('copyright', type=str, required=False, location='json')
-    parser.add_argument('privacy_policy', type=str, required=False, location='json')
-    parser.add_argument('custom_disclaimer', type=str, required=False, location='json')
-    parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'],
-                        required=False,
-                        location='json')
-    parser.add_argument('prompt_public', type=bool, required=False, location='json')
-    parser.add_argument('show_workflow_steps', type=bool, required=False, location='json')
+    parser.add_argument("title", type=str, required=False, location="json")
+    parser.add_argument("icon_type", type=str, required=False, location="json")
+    parser.add_argument("icon", type=str, required=False, location="json")
+    parser.add_argument("icon_background", type=str, required=False, location="json")
+    parser.add_argument("description", type=str, required=False, location="json")
+    parser.add_argument("default_language", type=supported_language, required=False, location="json")
+    parser.add_argument("chat_color_theme", type=str, required=False, location="json")
+    parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
+    parser.add_argument("customize_domain", type=str, required=False, location="json")
+    parser.add_argument("copyright", type=str, required=False, location="json")
+    parser.add_argument("privacy_policy", type=str, required=False, location="json")
+    parser.add_argument("custom_disclaimer", type=str, required=False, location="json")
+    parser.add_argument(
+        "customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json"
+    )
+    parser.add_argument("prompt_public", type=bool, required=False, location="json")
+    parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
+    parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
     return parser.parse_args()
 
 
@@ -47,37 +51,38 @@ class AppSite(Resource):
         if not current_user.is_editor:
             raise Forbidden()
 
-        site = db.session.query(Site). \
-            filter(Site.app_id == app_model.id). \
-            one_or_404()
+        site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404()
 
         for attr_name in [
-            'title',
-            'icon',
-            'icon_background',
-            'description',
-            'default_language',
-            'chat_color_theme',
-            'chat_color_theme_inverted',
-            'customize_domain',
-            'copyright',
-            'privacy_policy',
-            'custom_disclaimer',
-            'customize_token_strategy',
-            'prompt_public',
-            'show_workflow_steps'
+            "title",
+            "icon_type",
+            "icon",
+            "icon_background",
+            "description",
+            "default_language",
+            "chat_color_theme",
+            "chat_color_theme_inverted",
+            "customize_domain",
+            "copyright",
+            "privacy_policy",
+            "custom_disclaimer",
+            "customize_token_strategy",
+            "prompt_public",
+            "show_workflow_steps",
+            "use_icon_as_answer_icon",
         ]:
             value = args.get(attr_name)
             if value is not None:
                 setattr(site, attr_name, value)
 
+        site.updated_by = current_user.id
+        site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
         db.session.commit()
 
         return site
 
 
 class AppSiteAccessTokenReset(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -94,10 +99,12 @@ class AppSiteAccessTokenReset(Resource):
             raise NotFound
 
         site.code = Site.generate_code(16)
+        site.updated_by = current_user.id
+        site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
         db.session.commit()
 
         return site
 
 
-api.add_resource(AppSite, '/apps/<uuid:app_id>/site')
-api.add_resource(AppSiteAccessTokenReset, '/apps/<uuid:app_id>/site/access-token-reset')
+api.add_resource(AppSite, "/apps/<uuid:app_id>/site")
+api.add_resource(AppSiteAccessTokenReset, "/apps/<uuid:app_id>/site/access-token-reset")
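
AppSite.post applies a partial update: each whitelisted attribute that arrived non-None is copied onto the row with setattr, then updated_by and updated_at are stamped before commit. A minimal sketch of the same pattern on a plain object, independent of SQLAlchemy; the class and field names are illustrative:

class SiteRecord:  # stand-in for the Site ORM model
    title = "Old title"
    icon = "bot"
    copyright = None


def apply_partial_update(site, args, allowed):
    # Only overwrite fields the client actually sent; None means "not provided",
    # so fields the client omitted keep their current values.
    for attr_name in allowed:
        value = args.get(attr_name)
        if value is not None:
            setattr(site, attr_name, value)


site = SiteRecord()
apply_partial_update(site, {"title": "New title", "icon": None}, ["title", "icon", "copyright"])
print(site.title, site.icon)  # New title bot
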
@@ -16,8 +16,7 @@ from libs.login import login_required
 from models.model import AppMode
 
 
-class DailyConversationStatistic(Resource):
-
+class DailyMessageStatistic(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -26,58 +25,52 @@ class DailyConversationStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''
-            SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
+        sql_query = """
+            SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count
             FROM messages where app_id = :app_id
-        '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
+        """
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'conversation_count': i.conversation_count
-                })
+                response_data.append({"date": str(i.date), "message_count": i.message_count})
 
-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})
 
 
-class DailyTerminalsStatistic(Resource):
-
+class DailyConversationStatistic(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -86,54 +79,103 @@ class DailyTerminalsStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''
-            SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
+        sql_query = """
+            SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
             FROM messages where app_id = :app_id
-        '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
+        """
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'terminal_count': i.terminal_count
-                })
+                response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
 
-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})
+
+
+class DailyTerminalsStatistic(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model
+    def get(self, app_model):
+        account = current_user
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        args = parser.parse_args()
+
+        sql_query = """
+            SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
+            FROM messages where app_id = :app_id
+        """
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+
+        timezone = pytz.timezone(account.timezone)
+        utc_timezone = pytz.utc
+
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
+            start_datetime = start_datetime.replace(second=0)
+
+            start_datetime_timezone = timezone.localize(start_datetime)
+            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
+
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
+
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
+            end_datetime = end_datetime.replace(second=0)
+
+            end_datetime_timezone = timezone.localize(end_datetime)
+            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
+
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc
+
+        sql_query += " GROUP BY date order by date"
+
+        response_data = []
+
+        with db.engine.begin() as conn:
+            rs = conn.execute(db.text(sql_query), arg_dict)
+            for i in rs:
+                response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
+
+        return jsonify({"data": response_data})
 
 
 class DailyTokenCostStatistic(Resource):
@@ -145,58 +187,53 @@ class DailyTokenCostStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''
+        sql_query = """
             SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
                 (sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
                 sum(total_price) as total_price
             FROM messages where app_id = :app_id
-        '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
+        """
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
            start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'token_count': i.token_count,
-                    'total_price': i.total_price,
-                    'currency': 'USD'
-                })
+                response_data.append(
+                    {"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"}
+                )
 
-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})
 
 
 class AverageSessionInteractionStatistic(Resource):
@@ -208,8 +245,8 @@ class AverageSessionInteractionStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
         sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
@@ -218,30 +255,30 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
     FROM conversations c
     JOIN messages m ON c.id = m.conversation_id
     WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and c.created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and c.created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and c.created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and c.created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
         sql_query += """
             GROUP BY m.conversation_id) subquery
@@ -250,18 +287,15 @@ GROUP BY date
 ORDER BY date"""
 
         response_data = []
 
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'interactions': float(i.interactions.quantize(Decimal('0.01')))
-                })
+                response_data.append(
+                    {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
+                )
 
-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})
 
 
 class UserSatisfactionRateStatistic(Resource):
@@ -273,57 +307,57 @@ class UserSatisfactionRateStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''
+        sql_query = """
             SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
                 COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
             FROM messages m
             LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
             WHERE m.app_id = :app_id
-        '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
+        """
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and m.created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and m.created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and m.created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and m.created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
-                })
+                response_data.append(
+                    {
+                        "date": str(i.date),
+                        "rate": round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
+                    }
+                )
 
-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})
 
 
 class AverageResponseTimeStatistic(Resource):
@@ -335,56 +369,51 @@ class AverageResponseTimeStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''
+        sql_query = """
             SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
                 AVG(provider_response_latency) as latency
             FROM messages
             WHERE app_id = :app_id
-        '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
+        """
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'latency': round(i.latency * 1000, 4)
-                })
+                response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)})
 
-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})
 
 
 class TokensPerSecondStatistic(Resource):
@@ -396,63 +425,59 @@ class TokensPerSecondStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+        sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
     CASE
         WHEN SUM(provider_response_latency) = 0 THEN 0
         ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
     END as tokens_per_second
 FROM messages
-WHERE app_id = :app_id'''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
+WHERE app_id = :app_id"""
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'tps': round(i.tokens_per_second, 4)
-                })
+                response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)})
 
-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})
 
 
-api.add_resource(DailyConversationStatistic, '/apps/<uuid:app_id>/statistics/daily-conversations')
-api.add_resource(DailyTerminalsStatistic, '/apps/<uuid:app_id>/statistics/daily-end-users')
-api.add_resource(DailyTokenCostStatistic, '/apps/<uuid:app_id>/statistics/token-costs')
-api.add_resource(AverageSessionInteractionStatistic, '/apps/<uuid:app_id>/statistics/average-session-interactions')
-api.add_resource(UserSatisfactionRateStatistic, '/apps/<uuid:app_id>/statistics/user-satisfaction-rate')
-api.add_resource(AverageResponseTimeStatistic, '/apps/<uuid:app_id>/statistics/average-response-time')
-api.add_resource(TokensPerSecondStatistic, '/apps/<uuid:app_id>/statistics/tokens-per-second')
+api.add_resource(DailyMessageStatistic, "/apps/<uuid:app_id>/statistics/daily-messages")
+api.add_resource(DailyConversationStatistic, "/apps/<uuid:app_id>/statistics/daily-conversations")
+api.add_resource(DailyTerminalsStatistic, "/apps/<uuid:app_id>/statistics/daily-end-users")
+api.add_resource(DailyTokenCostStatistic, "/apps/<uuid:app_id>/statistics/token-costs")
+api.add_resource(AverageSessionInteractionStatistic, "/apps/<uuid:app_id>/statistics/average-session-interactions")
+api.add_resource(UserSatisfactionRateStatistic, "/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
+api.add_resource(AverageResponseTimeStatistic, "/apps/<uuid:app_id>/statistics/average-response-time")
+api.add_resource(TokensPerSecondStatistic, "/apps/<uuid:app_id>/statistics/tokens-per-second")
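
Every statistics endpoint above follows one recipe: a parameterized SQL string with named binds, day bucketing in the account's timezone, optional start/end filters converted to UTC and appended before GROUP BY, then execution inside engine.begin(). A condensed sketch of that recipe outside Flask, assuming SQLAlchemy and pytz; the connection URL, timezone, and app id are placeholders:

from datetime import datetime

import pytz
from sqlalchemy import create_engine, text

engine = create_engine("postgresql://user:pass@localhost/dify")  # hypothetical URL

sql_query = """
    SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz)) AS date, count(*) AS message_count
    FROM messages WHERE app_id = :app_id
"""
arg_dict = {"tz": "Asia/Shanghai", "app_id": "app-id-here"}  # placeholder values

# Optional window: convert the user's local "start" into UTC before binding,
# because created_at is stored in UTC.
local_tz = pytz.timezone(arg_dict["tz"])
start_local = local_tz.localize(datetime.strptime("2024-08-01 00:00", "%Y-%m-%d %H:%M"))
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_local.astimezone(pytz.utc)

sql_query += " GROUP BY date ORDER BY date"

with engine.begin() as conn:
    for row in conn.execute(text(sql_query), arg_dict):
        print({"date": str(row.date), "message_count": row.message_count})
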
@@ -64,51 +64,51 @@ class DraftWorkflowApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
 
-        content_type = request.headers.get('Content-Type', '')
-
-        if 'application/json' in content_type:
+        content_type = request.headers.get("Content-Type", "")
+
+        if "application/json" in content_type:
             parser = reqparse.RequestParser()
-            parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
-            parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
-            parser.add_argument('hash', type=str, required=False, location='json')
+            parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
+            parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
+            parser.add_argument("hash", type=str, required=False, location="json")
             # TODO: set this to required=True after frontend is updated
-            parser.add_argument('environment_variables', type=list, required=False, location='json')
-            parser.add_argument('conversation_variables', type=list, required=False, location='json')
+            parser.add_argument("environment_variables", type=list, required=False, location="json")
+            parser.add_argument("conversation_variables", type=list, required=False, location="json")
             args = parser.parse_args()
-        elif 'text/plain' in content_type:
+        elif "text/plain" in content_type:
             try:
-                data = json.loads(request.data.decode('utf-8'))
-                if 'graph' not in data or 'features' not in data:
-                    raise ValueError('graph or features not found in data')
+                data = json.loads(request.data.decode("utf-8"))
+                if "graph" not in data or "features" not in data:
+                    raise ValueError("graph or features not found in data")
 
-                if not isinstance(data.get('graph'), dict) or not isinstance(data.get('features'), dict):
-                    raise ValueError('graph or features is not a dict')
+                if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
+                    raise ValueError("graph or features is not a dict")
 
                 args = {
-                    'graph': data.get('graph'),
-                    'features': data.get('features'),
-                    'hash': data.get('hash'),
-                    'environment_variables': data.get('environment_variables'),
-                    'conversation_variables': data.get('conversation_variables'),
+                    "graph": data.get("graph"),
+                    "features": data.get("features"),
+                    "hash": data.get("hash"),
+                    "environment_variables": data.get("environment_variables"),
+                    "conversation_variables": data.get("conversation_variables"),
                 }
             except json.JSONDecodeError:
-                return {'message': 'Invalid JSON data'}, 400
+                return {"message": "Invalid JSON data"}, 400
         else:
             abort(415)
 
         workflow_service = WorkflowService()
 
         try:
-            environment_variables_list = args.get('environment_variables') or []
+            environment_variables_list = args.get("environment_variables") or []
             environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
-            conversation_variables_list = args.get('conversation_variables') or []
+            conversation_variables_list = args.get("conversation_variables") or []
             conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
             workflow = workflow_service.sync_draft_workflow(
                 app_model=app_model,
-                graph=args['graph'],
-                features=args['features'],
-                unique_hash=args.get('hash'),
+                graph=args["graph"],
+                features=args["features"],
+                unique_hash=args.get("hash"),
                 account=current_user,
                 environment_variables=environment_variables,
                 conversation_variables=conversation_variables,
@@ -119,7 +119,7 @@ class DraftWorkflowApi(Resource):
         return {
             "result": "success",
             "hash": workflow.unique_hash,
-            "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at)
+            "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
         }
 
 
@@ -138,13 +138,11 @@ class DraftWorkflowImportApi(Resource):
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('data', type=str, required=True, nullable=False, location='json')
+        parser.add_argument("data", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         workflow = AppDslService.import_and_overwrite_workflow(
-            app_model=app_model,
-            data=args['data'],
-            account=current_user
+            app_model=app_model, data=args["data"], account=current_user
         )
 
         return workflow
@@ -162,21 +160,17 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-
 
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, location='json')
-        parser.add_argument('query', type=str, required=True, location='json', default='')
-        parser.add_argument('files', type=list, location='json')
-        parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument("inputs", type=dict, location="json")
+        parser.add_argument("query", type=str, required=True, location="json", default="")
+        parser.add_argument("files", type=list, location="json")
+        parser.add_argument("conversation_id", type=uuid_value, location="json")
         args = parser.parse_args()
 
         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=current_user,
-                args=args,
-                invoke_from=InvokeFrom.DEBUGGER,
-                streaming=True
+                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
             )
 
             return helper.compact_generate_response(response)
@@ -190,6 +184,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
             logging.exception("internal server error.")
             raise InternalServerError()
 
+
 class AdvancedChatDraftRunIterationNodeApi(Resource):
     @setup_required
     @login_required
@@ -202,18 +197,14 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-
 
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, location='json')
+        parser.add_argument("inputs", type=dict, location="json")
         args = parser.parse_args()
 
         try:
             response = AppGenerateService.generate_single_iteration(
-                app_model=app_model,
-                user=current_user,
-                node_id=node_id,
-                args=args,
-                streaming=True
+                app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
             )
 
             return helper.compact_generate_response(response)
@@ -227,6 +218,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
             logging.exception("internal server error.")
             raise InternalServerError()
 
+
 class WorkflowDraftRunIterationNodeApi(Resource):
     @setup_required
     @login_required
@@ -239,18 +231,14 @@ class WorkflowDraftRunIterationNodeApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor:
             raise Forbidden()
-
 
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, location='json')
+        parser.add_argument("inputs", type=dict, location="json")
         args = parser.parse_args()
 
         try:
             response = AppGenerateService.generate_single_iteration(
-                app_model=app_model,
-                user=current_user,
-                node_id=node_id,
-                args=args,
-                streaming=True
+                app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
             )
 
             return helper.compact_generate_response(response)
@@ -264,6 +252,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
             logging.exception("internal server error.")
             raise InternalServerError()
 
+
 class DraftWorkflowRunApi(Resource):
     @setup_required
     @login_required
@@ -276,19 +265,15 @@ class DraftWorkflowRunApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-
 
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
-        parser.add_argument('files', type=list, required=False, location='json')
+        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("files", type=list, required=False, location="json")
         args = parser.parse_args()
 
         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=current_user,
-                args=args,
-                invoke_from=InvokeFrom.DEBUGGER,
-                streaming=True
+                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
             )
 
             return helper.compact_generate_response(response)
@@ -311,12 +296,10 @@ class WorkflowTaskStopApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
 
-        return {
-            "result": "success"
-        }
+        return {"result": "success"}
 
 
 class DraftWorkflowNodeRunApi(Resource):
@@ -332,24 +315,20 @@ class DraftWorkflowNodeRunApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-
 
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         workflow_service = WorkflowService()
         workflow_node_execution = workflow_service.run_draft_workflow_node(
-            app_model=app_model,
-            node_id=node_id,
-            user_inputs=args.get('inputs'),
-            account=current_user
+            app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user
         )
 
         return workflow_node_execution
 
 
 class PublishedWorkflowApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -362,7 +341,7 @@ class PublishedWorkflowApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-
+
         # fetch published workflow by app_model
         workflow_service = WorkflowService()
         workflow = workflow_service.get_published_workflow(app_model=app_model)
@@ -381,14 +360,11 @@ class PublishedWorkflowApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-
 
         workflow_service = WorkflowService()
         workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
 
-        return {
-            "result": "success",
-            "created_at": TimestampField().format(workflow.created_at)
-        }
+        return {"result": "success", "created_at": TimestampField().format(workflow.created_at)}
 
 
 class DefaultBlockConfigsApi(Resource):
@@ -403,7 +379,7 @@ class DefaultBlockConfigsApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-
+
         # Get default block configs
         workflow_service = WorkflowService()
         return workflow_service.get_default_block_configs()
@@ -421,24 +397,21 @@ class DefaultBlockConfigApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-
 
         parser = reqparse.RequestParser()
-        parser.add_argument('q', type=str, location='args')
+        parser.add_argument("q", type=str, location="args")
         args = parser.parse_args()
 
         filters = None
-        if args.get('q'):
+        if args.get("q"):
             try:
-                filters = json.loads(args.get('q'))
+                filters = json.loads(args.get("q"))
             except json.JSONDecodeError:
-                raise ValueError('Invalid filters')
+                raise ValueError("Invalid filters")
 
         # Get default block configs
         workflow_service = WorkflowService()
-        return workflow_service.get_default_block_config(
-            node_type=block_type,
-            filters=filters
-        )
+        return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
 
 
 class ConvertToWorkflowApi(Resource):
@@ -455,40 +428,43 @@ class ConvertToWorkflowApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-
 
         if request.data:
             parser = reqparse.RequestParser()
-            parser.add_argument('name', type=str, required=False, nullable=True, location='json')
-            parser.add_argument('icon', type=str, required=False, nullable=True, location='json')
-            parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json')
+            parser.add_argument("name", type=str, required=False, nullable=True, location="json")
+            parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
+            parser.add_argument("icon", type=str, required=False, nullable=True, location="json")
+            parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
            args = parser.parse_args()
         else:
             args = {}
 
         # convert to workflow mode
         workflow_service = WorkflowService()
-        new_app_model = workflow_service.convert_to_workflow(
-            app_model=app_model,
-            account=current_user,
-            args=args
-        )
+        new_app_model = workflow_service.convert_to_workflow(app_model=app_model, account=current_user, args=args)
 
         # return app id
         return {
-            'new_app_id': new_app_model.id,
+            "new_app_id": new_app_model.id,
         }
 
 
-api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft')
-api.add_resource(DraftWorkflowImportApi, '/apps/<uuid:app_id>/workflows/draft/import')
-api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run')
-api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
-api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop')
-api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
-api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run')
-api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run')
-api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish')
-api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
-api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs'
-                                        '/<string:block_type>')
-api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow')
+api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
+api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import")
+api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
+api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
+api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
+api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
+api.add_resource(
+    AdvancedChatDraftRunIterationNodeApi,
+    "/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
+)
+api.add_resource(
+    WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run"
+)
+api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
+api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
+api.add_resource(
    DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>"
+)
+api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")
@@ -22,20 +22,19 @@ class WorkflowAppLogApi(Resource):
         Get workflow app logs
         """
         parser = reqparse.RequestParser()
-        parser.add_argument('keyword', type=str, location='args')
-        parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args')
-        parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
+        parser.add_argument("keyword", type=str, location="args")
+        parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
+        parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
         args = parser.parse_args()

         # get paginate workflow app logs
         workflow_app_service = WorkflowAppService()
         workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
-            app_model=app_model,
-            args=args
+            app_model=app_model, args=args
         )

         return workflow_app_log_pagination


-api.add_resource(WorkflowAppLogApi, '/apps/<uuid:app_id>/workflow-app-logs')
+api.add_resource(WorkflowAppLogApi, "/apps/<uuid:app_id>/workflow-app-logs")
@@ -28,15 +28,12 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
         Get advanced chat app workflow run list
         """
         parser = reqparse.RequestParser()
-        parser.add_argument('last_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument("last_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
         args = parser.parse_args()

         workflow_run_service = WorkflowRunService()
-        result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
-            app_model=app_model,
-            args=args
-        )
+        result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)

         return result
@@ -52,15 +49,12 @@ class WorkflowRunListApi(Resource):
         Get workflow run list
         """
         parser = reqparse.RequestParser()
-        parser.add_argument('last_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument("last_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
         args = parser.parse_args()

         workflow_run_service = WorkflowRunService()
-        result = workflow_run_service.get_paginate_workflow_runs(
-            app_model=app_model,
-            args=args
-        )
+        result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)

         return result
@@ -98,12 +92,10 @@ class WorkflowRunNodeExecutionListApi(Resource):
         workflow_run_service = WorkflowRunService()
         node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)

-        return {
-            'data': node_executions
-        }
+        return {"data": node_executions}


-api.add_resource(AdvancedChatAppWorkflowRunListApi, '/apps/<uuid:app_id>/advanced-chat/workflow-runs')
-api.add_resource(WorkflowRunListApi, '/apps/<uuid:app_id>/workflow-runs')
-api.add_resource(WorkflowRunDetailApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>')
-api.add_resource(WorkflowRunNodeExecutionListApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions')
+api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps/<uuid:app_id>/advanced-chat/workflow-runs")
+api.add_resource(WorkflowRunListApi, "/apps/<uuid:app_id>/workflow-runs")
+api.add_resource(WorkflowRunDetailApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>")
+api.add_resource(WorkflowRunNodeExecutionListApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions")
@@ -26,56 +26,56 @@ class WorkflowDailyRunsStatistic(Resource):
         account = current_user

         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()

-        sql_query = '''
+        sql_query = """
             SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
             FROM workflow_runs
             WHERE app_id = :app_id
                 AND triggered_from = :triggered_from
-        '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
+        """
+        arg_dict = {
+            "tz": account.timezone,
+            "app_id": app_model.id,
+            "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
+        }

         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc

-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)

             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc

-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)

             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc

-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"

         response_data = []

         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'runs': i.runs
-                })
+                response_data.append({"date": str(i.date), "runs": i.runs})

-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})


 class WorkflowDailyTerminalsStatistic(Resource):
     @setup_required
@@ -86,56 +86,56 @@ class WorkflowDailyTerminalsStatistic(Resource):
         account = current_user

         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()

-        sql_query = '''
+        sql_query = """
            SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
            FROM workflow_runs
            WHERE app_id = :app_id
                AND triggered_from = :triggered_from
-        '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
+        """
+        arg_dict = {
+            "tz": account.timezone,
+            "app_id": app_model.id,
+            "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
+        }

         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc

-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)

             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc

-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)

             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc

-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"

         response_data = []

         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'terminal_count': i.terminal_count
-                })
+                response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})

-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})


 class WorkflowDailyTokenCostStatistic(Resource):
     @setup_required
@@ -146,58 +146,63 @@ class WorkflowDailyTokenCostStatistic(Resource):
         account = current_user

         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()

-        sql_query = '''
+        sql_query = """
            SELECT
                date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
                SUM(workflow_runs.total_tokens) as token_count
            FROM workflow_runs
            WHERE app_id = :app_id
                AND triggered_from = :triggered_from
-        '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
+        """
+        arg_dict = {
+            "tz": account.timezone,
+            "app_id": app_model.id,
+            "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
+        }

         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc

-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)

             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc

-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)

             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc

-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"

         response_data = []

         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'token_count': i.token_count,
-                })
+                response_data.append(
+                    {
+                        "date": str(i.date),
+                        "token_count": i.token_count,
+                    }
+                )

-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})


 class WorkflowAverageAppInteractionStatistic(Resource):
     @setup_required
@@ -208,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
         account = current_user

         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()

         sql_query = """
@@ -229,50 +234,54 @@ class WorkflowAverageAppInteractionStatistic(Resource):
                 GROUP BY date, c.created_by) sub
             GROUP BY sub.date
         """
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
+        arg_dict = {
+            "tz": account.timezone,
+            "app_id": app_model.id,
+            "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
+        }

         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc

-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)

             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query = sql_query.replace('{{start}}', ' AND c.created_at >= :start')
-            arg_dict['start'] = start_datetime_utc
+            sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
+            arg_dict["start"] = start_datetime_utc
         else:
-            sql_query = sql_query.replace('{{start}}', '')
+            sql_query = sql_query.replace("{{start}}", "")

-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)

             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query = sql_query.replace('{{end}}', ' and c.created_at < :end')
-            arg_dict['end'] = end_datetime_utc
+            sql_query = sql_query.replace("{{end}}", " and c.created_at < :end")
+            arg_dict["end"] = end_datetime_utc
         else:
-            sql_query = sql_query.replace('{{end}}', '')
+            sql_query = sql_query.replace("{{end}}", "")

         response_data = []

         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'interactions': float(i.interactions.quantize(Decimal('0.01')))
-                })
+                response_data.append(
+                    {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
+                )

-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})


-api.add_resource(WorkflowDailyRunsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-conversations')
-api.add_resource(WorkflowDailyTerminalsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-terminals')
-api.add_resource(WorkflowDailyTokenCostStatistic, '/apps/<uuid:app_id>/workflow/statistics/token-costs')
-api.add_resource(WorkflowAverageAppInteractionStatistic, '/apps/<uuid:app_id>/workflow/statistics/average-app-interactions')
+api.add_resource(WorkflowDailyRunsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
+api.add_resource(WorkflowDailyTerminalsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
+api.add_resource(WorkflowDailyTokenCostStatistic, "/apps/<uuid:app_id>/workflow/statistics/token-costs")
+api.add_resource(
+    WorkflowAverageAppInteractionStatistic, "/apps/<uuid:app_id>/workflow/statistics/average-app-interactions"
+)
@@ -8,24 +8,23 @@ from libs.login import current_user
 from models.model import App, AppMode


-def get_app_model(view: Optional[Callable] = None, *,
-                  mode: Union[AppMode, list[AppMode]] = None):
+def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
     def decorator(view_func):
         @wraps(view_func)
         def decorated_view(*args, **kwargs):
-            if not kwargs.get('app_id'):
-                raise ValueError('missing app_id in path parameters')
+            if not kwargs.get("app_id"):
+                raise ValueError("missing app_id in path parameters")

-            app_id = kwargs.get('app_id')
+            app_id = kwargs.get("app_id")
             app_id = str(app_id)

-            del kwargs['app_id']
+            del kwargs["app_id"]

-            app_model = db.session.query(App).filter(
-                App.id == app_id,
-                App.tenant_id == current_user.current_tenant_id,
-                App.status == 'normal'
-            ).first()
+            app_model = (
+                db.session.query(App)
+                .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+                .first()
+            )

             if not app_model:
                 raise AppNotFoundError()
@@ -44,9 +43,10 @@ def get_app_model(view: Optional[Callable] = None, *,
                 mode_values = {m.value for m in modes}
                 raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")

-            kwargs['app_model'] = app_model
+            kwargs["app_model"] = app_model

             return view_func(*args, **kwargs)

         return decorated_view

     if view is None:
@@ -17,60 +17,61 @@ from services.account_service import RegisterService
 class ActivateCheckApi(Resource):
     def get(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args')
-        parser.add_argument('email', type=email, required=False, nullable=True, location='args')
-        parser.add_argument('token', type=str, required=True, nullable=False, location='args')
+        parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args")
+        parser.add_argument("email", type=email, required=False, nullable=True, location="args")
+        parser.add_argument("token", type=str, required=True, nullable=False, location="args")
         args = parser.parse_args()

-        workspaceId = args['workspace_id']
-        reg_email = args['email']
-        token = args['token']
+        workspaceId = args["workspace_id"]
+        reg_email = args["email"]
+        token = args["token"]

         invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)

-        return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None}
+        return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None}


 class ActivateApi(Resource):
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json')
-        parser.add_argument('email', type=email, required=False, nullable=True, location='json')
-        parser.add_argument('token', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
-        parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
-        parser.add_argument('interface_language', type=supported_language, required=True, nullable=False,
-                            location='json')
-        parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
+        parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("email", type=email, required=False, nullable=True, location="json")
+        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
+        parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
+        parser.add_argument(
+            "interface_language", type=supported_language, required=True, nullable=False, location="json"
+        )
+        parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
         args = parser.parse_args()

-        invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token'])
+        invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
         if invitation is None:
             raise AlreadyActivateError()

-        RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
+        RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])

-        account = invitation['account']
-        account.name = args['name']
+        account = invitation["account"]
+        account.name = args["name"]

         # generate password salt
         salt = secrets.token_bytes(16)
         base64_salt = base64.b64encode(salt).decode()

         # encrypt password with salt
-        password_hashed = hash_password(args['password'], salt)
+        password_hashed = hash_password(args["password"], salt)
         base64_password_hashed = base64.b64encode(password_hashed).decode()
         account.password = base64_password_hashed
         account.password_salt = base64_salt
-        account.interface_language = args['interface_language']
-        account.timezone = args['timezone']
-        account.interface_theme = 'light'
+        account.interface_language = args["interface_language"]
+        account.timezone = args["timezone"]
+        account.interface_theme = "light"
         account.status = AccountStatus.ACTIVE.value
         account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
         db.session.commit()

-        return {'result': 'success'}
+        return {"result": "success"}


-api.add_resource(ActivateCheckApi, '/activate/check')
-api.add_resource(ActivateApi, '/activate')
+api.add_resource(ActivateCheckApi, "/activate/check")
+api.add_resource(ActivateApi, "/activate")
@@ -19,18 +19,19 @@ class ApiKeyAuthDataSource(Resource):
         data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
         if data_source_api_key_bindings:
             return {
-                'sources': [{
-                    'id': data_source_api_key_binding.id,
-                    'category': data_source_api_key_binding.category,
-                    'provider': data_source_api_key_binding.provider,
-                    'disabled': data_source_api_key_binding.disabled,
-                    'created_at': int(data_source_api_key_binding.created_at.timestamp()),
-                    'updated_at': int(data_source_api_key_binding.updated_at.timestamp()),
-                }
-                    for data_source_api_key_binding in
-                    data_source_api_key_bindings]
+                "sources": [
+                    {
+                        "id": data_source_api_key_binding.id,
+                        "category": data_source_api_key_binding.category,
+                        "provider": data_source_api_key_binding.provider,
+                        "disabled": data_source_api_key_binding.disabled,
+                        "created_at": int(data_source_api_key_binding.created_at.timestamp()),
+                        "updated_at": int(data_source_api_key_binding.updated_at.timestamp()),
+                    }
+                    for data_source_api_key_binding in data_source_api_key_bindings
+                ]
             }
-        return {'sources': []}
+        return {"sources": []}


 class ApiKeyAuthDataSourceBinding(Resource):
@@ -42,16 +43,16 @@ class ApiKeyAuthDataSourceBinding(Resource):
         if not current_user.is_admin_or_owner:
             raise Forbidden()
         parser = reqparse.RequestParser()
-        parser.add_argument('category', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument("category", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
         args = parser.parse_args()
         ApiKeyAuthService.validate_api_key_auth_args(args)
         try:
             ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
         except Exception as e:
             raise ApiKeyAuthFailedError(str(e))
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200


 class ApiKeyAuthDataSourceBindingDelete(Resource):
@@ -65,9 +66,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
         ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)

-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200


-api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source')
-api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding')
-api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>')
+api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")
+api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding")
+api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>")
@@ -17,13 +17,13 @@ from ..wraps import account_initialization_required
 def get_oauth_providers():
     with current_app.app_context():
-        notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID,
-                                   client_secret=dify_config.NOTION_CLIENT_SECRET,
-                                   redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion')
+        notion_oauth = NotionOAuth(
+            client_id=dify_config.NOTION_CLIENT_ID,
+            client_secret=dify_config.NOTION_CLIENT_SECRET,
+            redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion",
+        )

-        OAUTH_PROVIDERS = {
-            'notion': notion_oauth
-        }
+        OAUTH_PROVIDERS = {"notion": notion_oauth}
         return OAUTH_PROVIDERS
@@ -37,18 +37,16 @@ class OAuthDataSource(Resource):
         oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
         print(vars(oauth_provider))
         if not oauth_provider:
-            return {'error': 'Invalid provider'}, 400
-        if dify_config.NOTION_INTEGRATION_TYPE == 'internal':
+            return {"error": "Invalid provider"}, 400
+        if dify_config.NOTION_INTEGRATION_TYPE == "internal":
             internal_secret = dify_config.NOTION_INTERNAL_SECRET
             if not internal_secret:
-                return {'error': 'Internal secret is not set'},
+                return ({"error": "Internal secret is not set"},)
             oauth_provider.save_internal_access_token(internal_secret)
-            return { 'data': '' }
+            return {"data": ""}
         else:
             auth_url = oauth_provider.get_authorization_url()
-            return { 'data': auth_url }, 200
+            return {"data": auth_url}, 200


 class OAuthDataSourceCallback(Resource):
@@ -57,18 +55,18 @@
         with current_app.app_context():
             oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
             if not oauth_provider:
-                return {'error': 'Invalid provider'}, 400
-            if 'code' in request.args:
-                code = request.args.get('code')
+                return {"error": "Invalid provider"}, 400
+            if "code" in request.args:
+                code = request.args.get("code")

-                return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}')
-            elif 'error' in request.args:
-                error = request.args.get('error')
+                return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}")
+            elif "error" in request.args:
+                error = request.args.get("error")

-                return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}')
+                return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}")
             else:
-                return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied')
+                return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied")


 class OAuthDataSourceBinding(Resource):
     def get(self, provider: str):
@@ -76,17 +74,18 @@
         with current_app.app_context():
             oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
             if not oauth_provider:
-                return {'error': 'Invalid provider'}, 400
-            if 'code' in request.args:
-                code = request.args.get('code')
+                return {"error": "Invalid provider"}, 400
+            if "code" in request.args:
+                code = request.args.get("code")
                 try:
                     oauth_provider.get_access_token(code)
                 except requests.exceptions.HTTPError as e:
                     logging.exception(
-                        f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
-                    return {'error': 'OAuth data source process failed'}, 400
+                        f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}"
+                    )
+                    return {"error": "OAuth data source process failed"}, 400

-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200


 class OAuthDataSourceSync(Resource):
@@ -100,18 +99,17 @@
         with current_app.app_context():
             oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
             if not oauth_provider:
-                return {'error': 'Invalid provider'}, 400
+                return {"error": "Invalid provider"}, 400
             try:
                 oauth_provider.sync_data_source(binding_id)
             except requests.exceptions.HTTPError as e:
-                logging.exception(
-                    f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
-                return {'error': 'OAuth data source process failed'}, 400
+                logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
+                return {"error": "OAuth data source process failed"}, 400

-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200


-api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
-api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
-api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/<string:provider>')
-api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')
+api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
+api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
+api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")
+api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
@@ -2,31 +2,30 @@ from libs.exception import BaseHTTPException
 class ApiKeyAuthFailedError(BaseHTTPException):
-    error_code = 'auth_failed'
+    error_code = "auth_failed"
     description = "{message}"
     code = 500


 class InvalidEmailError(BaseHTTPException):
-    error_code = 'invalid_email'
+    error_code = "invalid_email"
     description = "The email address is not valid."
     code = 400


 class PasswordMismatchError(BaseHTTPException):
-    error_code = 'password_mismatch'
+    error_code = "password_mismatch"
     description = "The passwords do not match."
     code = 400


 class InvalidTokenError(BaseHTTPException):
-    error_code = 'invalid_or_expired_token'
+    error_code = "invalid_or_expired_token"
     description = "The token is invalid or has expired."
     code = 400


 class PasswordResetRateLimitExceededError(BaseHTTPException):
-    error_code = 'password_reset_rate_limit_exceeded'
+    error_code = "password_reset_rate_limit_exceeded"
     description = "Password reset rate limit exceeded. Try again later."
     code = 429
@@ -21,14 +21,13 @@ from services.errors.account import RateLimitExceededError
 class ForgotPasswordSendEmailApi(Resource):
     @setup_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('email', type=str, required=True, location='json')
+        parser.add_argument("email", type=str, required=True, location="json")
         args = parser.parse_args()

-        email = args['email']
+        email = args["email"]

         if not email_validate(email):
             raise InvalidEmailError()
@@ -49,38 +48,36 @@ class ForgotPasswordSendEmailApi(Resource):
 class ForgotPasswordCheckApi(Resource):
     @setup_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('token', type=str, required=True, nullable=False, location='json')
+        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
-        token = args['token']
+        token = args["token"]

         reset_data = AccountService.get_reset_password_data(token)

         if reset_data is None:
-            return {'is_valid': False, 'email': None}
-        return {'is_valid': True, 'email': reset_data.get('email')}
+            return {"is_valid": False, "email": None}
+        return {"is_valid": True, "email": reset_data.get("email")}


 class ForgotPasswordResetApi(Resource):
     @setup_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('token', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
-        parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
+        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
+        parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
         args = parser.parse_args()

-        new_password = args['new_password']
-        password_confirm = args['password_confirm']
+        new_password = args["new_password"]
+        password_confirm = args["password_confirm"]

         if str(new_password).strip() != str(password_confirm).strip():
             raise PasswordMismatchError()

-        token = args['token']
+        token = args["token"]
         reset_data = AccountService.get_reset_password_data(token)

         if reset_data is None:
@@ -94,14 +91,14 @@ class ForgotPasswordResetApi(Resource):
         password_hashed = hash_password(new_password, salt)
         base64_password_hashed = base64.b64encode(password_hashed).decode()

-        account = Account.query.filter_by(email=reset_data.get('email')).first()
+        account = Account.query.filter_by(email=reset_data.get("email")).first()
         account.password = base64_password_hashed
         account.password_salt = base64_salt
         db.session.commit()

-        return {'result': 'success'}
+        return {"result": "success"}


-api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
-api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
-api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')
+api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
+api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
+api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
@@ -20,37 +20,39 @@ class LoginApi(Resource):
     def post(self):
         """Authenticate user and login."""
         parser = reqparse.RequestParser()
-        parser.add_argument('email', type=email, required=True, location='json')
-        parser.add_argument('password', type=valid_password, required=True, location='json')
-        parser.add_argument('remember_me', type=bool, required=False, default=False, location='json')
+        parser.add_argument("email", type=email, required=True, location="json")
+        parser.add_argument("password", type=valid_password, required=True, location="json")
+        parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
         args = parser.parse_args()

         # todo: Verify the recaptcha

         try:
-            account = AccountService.authenticate(args['email'], args['password'])
+            account = AccountService.authenticate(args["email"], args["password"])
         except services.errors.account.AccountLoginError as e:
-            return {'code': 'unauthorized', 'message': str(e)}, 401
+            return {"code": "unauthorized", "message": str(e)}, 401

         # SELF_HOSTED only have one workspace
         tenants = TenantService.get_join_tenants(account)
         if len(tenants) == 0:
-            return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}
+            return {
+                "result": "fail",
+                "data": "workspace not found, please contact system admin to invite you to join in a workspace",
+            }

         token = AccountService.login(account, ip_address=get_remote_ip(request))

-        return {'result': 'success', 'data': token}
+        return {"result": "success", "data": token}


 class LogoutApi(Resource):
     @setup_required
     def get(self):
         account = cast(Account, flask_login.current_user)
-        token = request.headers.get('Authorization', '').split(' ')[1]
+        token = request.headers.get("Authorization", "").split(" ")[1]
         AccountService.logout(account=account, token=token)
         flask_login.logout_user()
-        return {'result': 'success'}
+        return {"result": "success"}


 class ResetPasswordApi(Resource):
@@ -80,11 +82,11 @@ class ResetPasswordApi(Resource):
         # 'subject': 'Reset your Dify password',
         # 'html': """
         #     <p>Dear User,</p>
         #     <p>The Dify team has generated a new password for you, details as follows:</p>
         #     <p><strong>{new_password}</strong></p>
         #     <p>Please change your password to log in as soon as possible.</p>
         #     <p>Regards,</p>
         #     <p>The Dify Team</p>
         # """
         # }
@@ -101,8 +103,8 @@ class ResetPasswordApi(Resource):
         # # handle error
         # pass

-        return {'result': 'success'}
+        return {"result": "success"}


-api.add_resource(LoginApi, '/login')
-api.add_resource(LogoutApi, '/logout')
+api.add_resource(LoginApi, "/login")
+api.add_resource(LogoutApi, "/logout")
@@ -25,7 +25,7 @@ def get_oauth_providers():
         github_oauth = GitHubOAuth(
             client_id=dify_config.GITHUB_CLIENT_ID,
             client_secret=dify_config.GITHUB_CLIENT_SECRET,
-            redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github',
+            redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github",
         )
         if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET:
             google_oauth = None
@@ -33,10 +33,10 @@ def get_oauth_providers():
         google_oauth = GoogleOAuth(
             client_id=dify_config.GOOGLE_CLIENT_ID,
             client_secret=dify_config.GOOGLE_CLIENT_SECRET,
-            redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google',
+            redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
         )

-        OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth}
+        OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth}
         return OAUTH_PROVIDERS
@@ -47,7 +47,7 @@ class OAuthLogin(Resource):
         oauth_provider = OAUTH_PROVIDERS.get(provider)
         print(vars(oauth_provider))
         if not oauth_provider:
-            return {'error': 'Invalid provider'}, 400
+            return {"error": "Invalid provider"}, 400

         auth_url = oauth_provider.get_authorization_url()
         return redirect(auth_url)
@@ -59,20 +59,20 @@ class OAuthCallback(Resource):
         with current_app.app_context():
             oauth_provider = OAUTH_PROVIDERS.get(provider)
             if not oauth_provider:
-                return {'error': 'Invalid provider'}, 400
+                return {"error": "Invalid provider"}, 400

-        code = request.args.get('code')
+        code = request.args.get("code")
         try:
             token = oauth_provider.get_access_token(code)
             user_info = oauth_provider.get_user_info(token)
         except requests.exceptions.HTTPError as e:
-            logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}')
-            return {'error': 'OAuth process failed'}, 400
+            logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
+            return {"error": "OAuth process failed"}, 400

         account = _generate_account(provider, user_info)
         # Check account status
         if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
-            return {'error': 'Account is banned or closed.'}, 403
+            return {"error": "Account is banned or closed."}, 403

         if account.status == AccountStatus.PENDING.value:
             account.status = AccountStatus.ACTIVE.value
@@ -83,7 +83,7 @@ class OAuthCallback(Resource):
         token = AccountService.login(account, ip_address=get_remote_ip(request))

-        return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}')
+        return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")


 def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
@@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
     if not account:
         # Create account
-        account_name = user_info.name if user_info.name else 'Dify'
+        account_name = user_info.name if user_info.name else "Dify"
         account = RegisterService.register(
             email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
         )
@@ -121,5 +121,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
     return account


-api.add_resource(OAuthLogin, '/oauth/login/<provider>')
-api.add_resource(OAuthCallback, '/oauth/authorize/<provider>')
+api.add_resource(OAuthLogin, "/oauth/login/<provider>")
+api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")
@@ -9,28 +9,24 @@ from services.billing_service import BillingService
 class Subscription(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     @only_edition_cloud
     def get(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
-        parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
+        parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
+        parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
         args = parser.parse_args()

         BillingService.is_tenant_owner_or_admin(current_user)

-        return BillingService.get_subscription(args['plan'],
-                                               args['interval'],
-                                               current_user.email,
-                                               current_user.current_tenant_id)
+        return BillingService.get_subscription(
+            args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
+        )


 class Invoices(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -40,5 +36,5 @@ class Invoices(Resource):
         return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)


-api.add_resource(Subscription, '/billing/subscription')
-api.add_resource(Invoices, '/billing/invoices')
+api.add_resource(Subscription, "/billing/subscription")
+api.add_resource(Invoices, "/billing/invoices")
@@ -22,19 +22,22 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task
 class DataSourceApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     @marshal_with(integrate_list_fields)
     def get(self):
         # get workspace data source integrates
-        data_source_integrates = db.session.query(DataSourceOauthBinding).filter(
-            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-            DataSourceOauthBinding.disabled == False
-        ).all()
+        data_source_integrates = (
+            db.session.query(DataSourceOauthBinding)
+            .filter(
+                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceOauthBinding.disabled == False,
+            )
+            .all()
+        )

-        base_url = request.url_root.rstrip('/')
+        base_url = request.url_root.rstrip("/")
         data_source_oauth_base_path = "/console/api/oauth/data-source"
         providers = ["notion"]
@@ -44,26 +47,30 @@ class DataSourceApi(Resource):
             existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates)
             if existing_integrates:
                 for existing_integrate in list(existing_integrates):
-                    integrate_data.append({
-                        'id': existing_integrate.id,
-                        'provider': provider,
-                        'created_at': existing_integrate.created_at,
-                        'is_bound': True,
-                        'disabled': existing_integrate.disabled,
-                        'source_info': existing_integrate.source_info,
-                        'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
-                    })
+                    integrate_data.append(
+                        {
+                            "id": existing_integrate.id,
+                            "provider": provider,
+                            "created_at": existing_integrate.created_at,
+                            "is_bound": True,
+                            "disabled": existing_integrate.disabled,
+                            "source_info": existing_integrate.source_info,
+                            "link": f"{base_url}{data_source_oauth_base_path}/{provider}",
+                        }
+                    )
             else:
-                integrate_data.append({
-                    'id': None,
-                    'provider': provider,
-                    'created_at': None,
-                    'source_info': None,
-                    'is_bound': False,
-                    'disabled': None,
-                    'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
-                })
-        return {'data': integrate_data}, 200
+                integrate_data.append(
+                    {
+                        "id": None,
+                        "provider": provider,
+                        "created_at": None,
+                        "source_info": None,
+                        "is_bound": False,
+                        "disabled": None,
+                        "link": f"{base_url}{data_source_oauth_base_path}/{provider}",
+                    }
+                )
+        return {"data": integrate_data}, 200

     @setup_required
     @login_required
@@ -71,92 +78,82 @@
     def patch(self, binding_id, action):
         binding_id = str(binding_id)
         action = str(action)
-        data_source_binding = DataSourceOauthBinding.query.filter_by(
-            id=binding_id
-        ).first()
+        data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
         if data_source_binding is None:
-            raise NotFound('Data source binding not found.')
+            raise NotFound("Data source binding not found.")
         # enable binding
-        if action == 'enable':
+        if action == "enable":
             if data_source_binding.disabled:
                 data_source_binding.disabled = False
                 data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                 db.session.add(data_source_binding)
                 db.session.commit()
             else:
-                raise ValueError('Data source is not disabled.')
+                raise ValueError("Data source is not disabled.")
         # disable binding
-        if action == 'disable':
+        if action == "disable":
             if not data_source_binding.disabled:
                 data_source_binding.disabled = True
                 data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                 db.session.add(data_source_binding)
                 db.session.commit()
             else:
-                raise ValueError('Data source is disabled.')
-        return {'result': 'success'}, 200
+                raise ValueError("Data source is disabled.")
+        return {"result": "success"}, 200


 class DataSourceNotionListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     @marshal_with(integrate_notion_info_list_fields)
     def get(self):
-        dataset_id = request.args.get('dataset_id', default=None, type=str)
+        dataset_id = request.args.get("dataset_id", default=None, type=str)
         exist_page_ids = []
         # import notion in the exist dataset
         if dataset_id:
             dataset = DatasetService.get_dataset(dataset_id)
             if not dataset:
-                raise NotFound('Dataset not found.')
-            if dataset.data_source_type != 'notion_import':
-                raise ValueError('Dataset is not notion type.')
+                raise NotFound("Dataset not found.")
+            if dataset.data_source_type != "notion_import":
+                raise ValueError("Dataset is not notion type.")
             documents = Document.query.filter_by(
                 dataset_id=dataset_id,
                 tenant_id=current_user.current_tenant_id,
-                data_source_type='notion_import',
-                enabled=True
+                data_source_type="notion_import",
+                enabled=True,
             ).all()
             if documents:
                 for document in documents:
                     data_source_info = json.loads(document.data_source_info)
-                    exist_page_ids.append(data_source_info['notion_page_id'])
+                    exist_page_ids.append(data_source_info["notion_page_id"])
         # get all authorized pages
         data_source_bindings = DataSourceOauthBinding.query.filter_by(
-            tenant_id=current_user.current_tenant_id,
-            provider='notion',
-            disabled=False
+            tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
         ).all()
         if not data_source_bindings:
-            return {
-                'notion_info': []
-            }, 200
+            return {"notion_info": []}, 200
         pre_import_info_list = []
         for data_source_binding in data_source_bindings:
             source_info = data_source_binding.source_info
-            pages = source_info['pages']
+            pages = source_info["pages"]
             # Filter out already bound pages
             for page in pages:
-                if page['page_id'] in exist_page_ids:
-                    page['is_bound'] = True
+                if page["page_id"] in exist_page_ids:
+                    page["is_bound"] = True
                 else:
-                    page['is_bound'] = False
+                    page["is_bound"] = False
             pre_import_info = {
-                'workspace_name': source_info['workspace_name'],
-                'workspace_icon': source_info['workspace_icon'],
-                'workspace_id': source_info['workspace_id'],
-                'pages': pages,
+                "workspace_name": source_info["workspace_name"],
+                "workspace_icon": source_info["workspace_icon"],
+                "workspace_id": source_info["workspace_id"],
+                "pages": pages,
             }
             pre_import_info_list.append(pre_import_info)
-        return {
-            'notion_info': pre_import_info_list
-        }, 200
+        return {"notion_info": pre_import_info_list}, 200


 class DataSourceNotionApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -166,64 +163,67 @@
         data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.provider == "notion",
                 DataSourceOauthBinding.disabled == False,
-                DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
             )
         ).first()
         if not data_source_binding:
-            raise NotFound('Data source binding not found.')
+            raise NotFound("Data source binding not found.")

         extractor = NotionExtractor(
             notion_workspace_id=workspace_id,
             notion_obj_id=page_id,
             notion_page_type=page_type,
             notion_access_token=data_source_binding.access_token,
-            tenant_id=current_user.current_tenant_id
+            tenant_id=current_user.current_tenant_id,
         )

         text_docs = extractor.extract()
-        return {
-            'content': "\n".join([doc.page_content for doc in text_docs])
-        }, 200
+        return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200

     @setup_required
     @login_required
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
-        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
-        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
-        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
+        parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
+        parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
+        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+        parser.add_argument(
+            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+        )
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
-        notion_info_list = args['notion_info_list']
+        notion_info_list = args["notion_info_list"]
         extract_settings = []
         for notion_info in notion_info_list:
-            workspace_id = notion_info['workspace_id']
-            for page in notion_info['pages']:
+            workspace_id = notion_info["workspace_id"]
+            for page in notion_info["pages"]:
                 extract_setting = ExtractSetting(
                     datasource_type="notion_import",
                     notion_info={
                         "notion_workspace_id": workspace_id,
-                        "notion_obj_id": page['page_id'],
-                        "notion_page_type": page['type'],
-                        "tenant_id": current_user.current_tenant_id
+                        "notion_obj_id": page["page_id"],
+                        "notion_page_type": page["type"],
+                        "tenant_id": current_user.current_tenant_id,
                     },
-                    document_model=args['doc_form']
+                    document_model=args["doc_form"],
                 )
                 extract_settings.append(extract_setting)
         indexing_runner = IndexingRunner()
-        response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
-                                                     args['process_rule'], args['doc_form'],
-                                                     args['doc_language'])
+        response = indexing_runner.indexing_estimate(
+            current_user.current_tenant_id,
+            extract_settings,
+            args["process_rule"],
+            args["doc_form"],
+            args["doc_language"],
+        )
         return response, 200


 class DataSourceNotionDatasetSyncApi(Resource):
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
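Based on the argument parser above, a request to the notion indexing-estimate route would carry a body roughly like the following; the UUIDs are placeholders and the process_rule value is a hypothetical example, not taken from this commit:

# Hypothetical payload matching the reqparse arguments above.
payload = {
    "notion_info_list": [
        {
            "workspace_id": "<workspace-uuid>",
            "pages": [{"page_id": "<page-uuid>", "type": "page"}],
        }
    ],
    "process_rule": {"mode": "automatic"},
    "doc_form": "text_model",
    "doc_language": "English",
}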
@@ -240,7 +240,6 @@ class DataSourceNotionDatasetSyncApi(Resource):


 class DataSourceNotionDocumentSyncApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required

@@ -258,10 +257,14 @@ class DataSourceNotionDocumentSyncApi(Resource):
         return 200


-api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
-api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
-api.add_resource(DataSourceNotionApi,
-                 '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview',
-                 '/datasets/notion-indexing-estimate')
-api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets/<uuid:dataset_id>/notion/sync')
-api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync')
+api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
+api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")
+api.add_resource(
+    DataSourceNotionApi,
+    "/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
+    "/datasets/notion-indexing-estimate",
+)
+api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
+api.add_resource(
+    DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
+)
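The registrations above rely on Flask-RESTful's ability to bind one resource class to several URL rules; DataSourceNotionApi serves both the page preview and the indexing-estimate route. A self-contained sketch of the same pattern with a toy resource:

# Toy sketch of multi-URL registration in Flask-RESTful; both routes
# dispatch to the same resource class, mirroring DataSourceNotionApi above.
from flask import Flask
from flask_restful import Api, Resource

app = Flask(__name__)
api = Api(app)

class Ping(Resource):
    def get(self):
        return {"pong": True}

api.add_resource(Ping, "/ping", "/health/ping")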
@@ -24,52 +24,47 @@ from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
 from fields.document_fields import document_status_fields
 from libs.login import login_required
-from models.dataset import Dataset, Document, DocumentSegment
+from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
 from models.model import ApiToken, UploadFile
 from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService


 def _validate_name(name):
     if not name or len(name) < 1 or len(name) > 40:
-        raise ValueError('Name must be between 1 to 40 characters.')
+        raise ValueError("Name must be between 1 to 40 characters.")
     return name


 def _validate_description_length(description):
     if len(description) > 400:
-        raise ValueError('Description cannot exceed 400 characters.')
+        raise ValueError("Description cannot exceed 400 characters.")
     return description


 class DatasetListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     def get(self):
-        page = request.args.get('page', default=1, type=int)
-        limit = request.args.get('limit', default=20, type=int)
-        ids = request.args.getlist('ids')
-        provider = request.args.get('provider', default="vendor")
-        search = request.args.get('keyword', default=None, type=str)
-        tag_ids = request.args.getlist('tag_ids')
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
+        ids = request.args.getlist("ids")
+        provider = request.args.get("provider", default="vendor")
+        search = request.args.get("keyword", default=None, type=str)
+        tag_ids = request.args.getlist("tag_ids")

         if ids:
             datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
         else:
-            datasets, total = DatasetService.get_datasets(page, limit, provider,
-                                                          current_user.current_tenant_id, current_user, search, tag_ids)
+            datasets, total = DatasetService.get_datasets(
+                page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids
+            )

         # check embedding setting
         provider_manager = ProviderManager()
-        configurations = provider_manager.get_configurations(
-            tenant_id=current_user.current_tenant_id
-        )
+        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)

-        embedding_models = configurations.get_models(
-            model_type=ModelType.TEXT_EMBEDDING,
-            only_active=True
-        )
+        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

         model_names = []
         for embedding_model in embedding_models:
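The two module-level validators in this hunk are plain callables, which is exactly what Flask-RESTful's reqparse expects for `type=`: the parser calls the function with the raw value and turns its ValueError into a 400 response. A minimal sketch, reusing only the validator shown above:

# reqparse treats any callable as a type, so _validate_name doubles as
# both validator and converter; a failing call becomes a 400 whose
# message includes the ValueError text.
from flask_restful import reqparse

def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError("Name must be between 1 to 40 characters.")
    return name

parser = reqparse.RequestParser()
parser.add_argument("name", required=True, type=_validate_name)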
@@ -77,28 +72,22 @@ class DatasetListApi(Resource):

         data = marshal(datasets, dataset_detail_fields)
         for item in data:
-            if item['indexing_technique'] == 'high_quality':
+            if item["indexing_technique"] == "high_quality":
                 item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                 if item_model in model_names:
-                    item['embedding_available'] = True
+                    item["embedding_available"] = True
                 else:
-                    item['embedding_available'] = False
+                    item["embedding_available"] = False
             else:
-                item['embedding_available'] = True
+                item["embedding_available"] = True

-            if item.get('permission') == 'partial_members':
-                part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
-                item.update({'partial_member_list': part_users_list})
+            if item.get("permission") == "partial_members":
+                part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"])
+                item.update({"partial_member_list": part_users_list})
             else:
-                item.update({'partial_member_list': []})
+                item.update({"partial_member_list": []})

-        response = {
-            'data': data,
-            'has_more': len(datasets) == limit,
-            'limit': limit,
-            'total': total,
-            'page': page
-        }
+        response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
         return response, 200

     @setup_required
@@ -106,13 +95,21 @@ class DatasetListApi(Resource):
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('name', nullable=False, required=True,
-                            help='type is required. Name must be between 1 to 40 characters.',
-                            type=_validate_name)
-        parser.add_argument('indexing_technique', type=str, location='json',
-                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                            nullable=True,
-                            help='Invalid indexing technique.')
+        parser.add_argument(
+            "name",
+            nullable=False,
+            required=True,
+            help="type is required. Name must be between 1 to 40 characters.",
+            type=_validate_name,
+        )
+        parser.add_argument(
+            "indexing_technique",
+            type=str,
+            location="json",
+            choices=Dataset.INDEXING_TECHNIQUE_LIST,
+            nullable=True,
+            help="Invalid indexing technique.",
+        )
         args = parser.parse_args()

         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
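Taken together, the parser above accepts a creation body like this hypothetical example, where `indexing_technique` must be one of Dataset.INDEXING_TECHNIQUE_LIST ("high_quality" is one such value in this codebase):

# Hypothetical POST /datasets payload matching the two parsed arguments.
payload = {
    "name": "support-articles",
    "indexing_technique": "high_quality",
}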
@@ -122,9 +119,10 @@ class DatasetListApi(Resource):
         try:
             dataset = DatasetService.create_empty_dataset(
                 tenant_id=current_user.current_tenant_id,
-                name=args['name'],
-                indexing_technique=args['indexing_technique'],
-                account=current_user
+                name=args["name"],
+                indexing_technique=args["indexing_technique"],
+                account=current_user,
+                permission=DatasetPermissionEnum.ONLY_ME,
             )
         except services.errors.dataset.DatasetNameDuplicateError:
             raise DatasetNameDuplicateError()
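This commit threads a DatasetPermissionEnum through places that previously compared raw strings ('only_me', 'all_team_members', 'partial_members'). A hedged sketch of what such an enum looks like; the member names and values are taken from their usage in this diff, but the real definition lives in models.dataset and may differ in detail:

# Sketch of a string-valued permission enum consistent with this diff.
from enum import Enum

class DatasetPermissionEnum(str, Enum):
    ONLY_ME = "only_me"
    ALL_TEAM = "all_team_members"
    PARTIAL_TEAM = "partial_members"

# Because it subclasses str, members compare equal to raw request values,
# which is why data.get("permission") == DatasetPermissionEnum.ONLY_ME works.
assert DatasetPermissionEnum.ONLY_ME == "only_me"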
@@ -142,42 +140,36 @@ class DatasetApi(Resource):
         if dataset is None:
             raise NotFound("Dataset not found.")
         try:
-            DatasetService.check_dataset_permission(
-                dataset, current_user)
+            DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         data = marshal(dataset, dataset_detail_fields)
-        if data.get('permission') == 'partial_members':
+        if data.get("permission") == "partial_members":
             part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
-            data.update({'partial_member_list': part_users_list})
+            data.update({"partial_member_list": part_users_list})

         # check embedding setting
         provider_manager = ProviderManager()
-        configurations = provider_manager.get_configurations(
-            tenant_id=current_user.current_tenant_id
-        )
+        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)

-        embedding_models = configurations.get_models(
-            model_type=ModelType.TEXT_EMBEDDING,
-            only_active=True
-        )
+        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

         model_names = []
         for embedding_model in embedding_models:
             model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

-        if data['indexing_technique'] == 'high_quality':
+        if data["indexing_technique"] == "high_quality":
             item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
             if item_model in model_names:
-                data['embedding_available'] = True
+                data["embedding_available"] = True
             else:
-                data['embedding_available'] = False
+                data["embedding_available"] = False
         else:
-            data['embedding_available'] = True
+            data["embedding_available"] = True

-        if data.get('permission') == 'partial_members':
+        if data.get("permission") == "partial_members":
             part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
-            data.update({'partial_member_list': part_users_list})
+            data.update({"partial_member_list": part_users_list})

         return data, 200
@@ -191,42 +183,49 @@ class DatasetApi(Resource):
             raise NotFound("Dataset not found.")

         parser = reqparse.RequestParser()
-        parser.add_argument('name', nullable=False,
-                            help='type is required. Name must be between 1 to 40 characters.',
-                            type=_validate_name)
-        parser.add_argument('description',
-                            location='json', store_missing=False,
-                            type=_validate_description_length)
-        parser.add_argument('indexing_technique', type=str, location='json',
-                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                            nullable=True,
-                            help='Invalid indexing technique.')
-        parser.add_argument('permission', type=str, location='json', choices=(
-            'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.'
-        )
-        parser.add_argument('embedding_model', type=str,
-                            location='json', help='Invalid embedding model.')
-        parser.add_argument('embedding_model_provider', type=str,
-                            location='json', help='Invalid embedding model provider.')
-        parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
-        parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
+        parser.add_argument(
+            "name",
+            nullable=False,
+            help="type is required. Name must be between 1 to 40 characters.",
+            type=_validate_name,
+        )
+        parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
+        parser.add_argument(
+            "indexing_technique",
+            type=str,
+            location="json",
+            choices=Dataset.INDEXING_TECHNIQUE_LIST,
+            nullable=True,
+            help="Invalid indexing technique.",
+        )
+        parser.add_argument(
+            "permission",
+            type=str,
+            location="json",
+            choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
+            help="Invalid permission.",
+        )
+        parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
+        parser.add_argument(
+            "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
+        )
+        parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
+        parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
         args = parser.parse_args()
         data = request.get_json()

         # check embedding model setting
-        if data.get('indexing_technique') == 'high_quality':
-            DatasetService.check_embedding_model_setting(dataset.tenant_id,
-                                                         data.get('embedding_model_provider'),
-                                                         data.get('embedding_model')
-                                                         )
+        if data.get("indexing_technique") == "high_quality":
+            DatasetService.check_embedding_model_setting(
+                dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
+            )

         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
         DatasetPermissionService.check_permission(
-            current_user, dataset, data.get('permission'), data.get('partial_member_list')
+            current_user, dataset, data.get("permission"), data.get("partial_member_list")
         )

-        dataset = DatasetService.update_dataset(
-            dataset_id_str, args, current_user)
+        dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)

         if dataset is None:
             raise NotFound("Dataset not found.")
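As a concrete illustration of the arguments parsed above, a partial-members update would send something like this hypothetical body; the member IDs are placeholders, and "partial_members" matches the permission choices in the parser:

# Hypothetical PATCH body for updating a dataset's sharing settings.
payload = {
    "name": "support-articles",
    "permission": "partial_members",
    "partial_member_list": ["<account-uuid-1>", "<account-uuid-2>"],
}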
@@ -234,16 +233,19 @@ class DatasetApi(Resource):
         result_data = marshal(dataset, dataset_detail_fields)
         tenant_id = current_user.current_tenant_id

-        if data.get('partial_member_list') and data.get('permission') == 'partial_members':
+        if data.get("partial_member_list") and data.get("permission") == "partial_members":
             DatasetPermissionService.update_partial_member_list(
-                tenant_id, dataset_id_str, data.get('partial_member_list')
+                tenant_id, dataset_id_str, data.get("partial_member_list")
             )
         # clear partial member list when permission is only_me or all_team_members
-        elif data.get('permission') == 'only_me' or data.get('permission') == 'all_team_members':
+        elif (
+            data.get("permission") == DatasetPermissionEnum.ONLY_ME
+            or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
+        ):
             DatasetPermissionService.clear_partial_member_list(dataset_id_str)

         partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
-        result_data.update({'partial_member_list': partial_member_list})
+        result_data.update({"partial_member_list": partial_member_list})

         return result_data, 200
@@ -260,12 +262,13 @@ class DatasetApi(Resource):
         try:
             if DatasetService.delete_dataset(dataset_id_str, current_user):
                 DatasetPermissionService.clear_partial_member_list(dataset_id_str)
-                return {'result': 'success'}, 204
+                return {"result": "success"}, 204
             else:
                 raise NotFound("Dataset not found.")
         except services.errors.dataset.DatasetInUseError:
             raise DatasetInUseError()


 class DatasetUseCheckApi(Resource):
     @setup_required
     @login_required

@@ -274,10 +277,10 @@ class DatasetUseCheckApi(Resource):
         dataset_id_str = str(dataset_id)

         dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
-        return {'is_using': dataset_is_using}, 200
+        return {"is_using": dataset_is_using}, 200


 class DatasetQueryApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -292,51 +295,53 @@ class DatasetQueryApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))

-        page = request.args.get('page', default=1, type=int)
-        limit = request.args.get('limit', default=20, type=int)
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)

-        dataset_queries, total = DatasetService.get_dataset_queries(
-            dataset_id=dataset.id,
-            page=page,
-            per_page=limit
-        )
+        dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)

         response = {
-            'data': marshal(dataset_queries, dataset_query_detail_fields),
-            'has_more': len(dataset_queries) == limit,
-            'limit': limit,
-            'total': total,
-            'page': page
+            "data": marshal(dataset_queries, dataset_query_detail_fields),
+            "has_more": len(dataset_queries) == limit,
+            "limit": limit,
+            "total": total,
+            "page": page,
         }
         return response, 200


 class DatasetIndexingEstimateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
-        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
-        parser.add_argument('indexing_technique', type=str, required=True,
-                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                            nullable=True, location='json')
-        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
-        parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
-        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
-                            location='json')
+        parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
+        parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
+        parser.add_argument(
+            "indexing_technique",
+            type=str,
+            required=True,
+            choices=Dataset.INDEXING_TECHNIQUE_LIST,
+            nullable=True,
+            location="json",
+        )
+        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+        parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
+        parser.add_argument(
+            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+        )
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
         extract_settings = []
-        if args['info_list']['data_source_type'] == 'upload_file':
-            file_ids = args['info_list']['file_info_list']['file_ids']
-            file_details = db.session.query(UploadFile).filter(
-                UploadFile.tenant_id == current_user.current_tenant_id,
-                UploadFile.id.in_(file_ids)
-            ).all()
+        if args["info_list"]["data_source_type"] == "upload_file":
+            file_ids = args["info_list"]["file_info_list"]["file_ids"]
+            file_details = (
+                db.session.query(UploadFile)
+                .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
+                .all()
+            )

             if file_details is None:
                 raise NotFound("File not found.")
@@ -344,55 +349,58 @@ class DatasetIndexingEstimateApi(Resource):
             if file_details:
                 for file_detail in file_details:
                     extract_setting = ExtractSetting(
-                        datasource_type="upload_file",
-                        upload_file=file_detail,
-                        document_model=args['doc_form']
+                        datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
                     )
                     extract_settings.append(extract_setting)
-        elif args['info_list']['data_source_type'] == 'notion_import':
-            notion_info_list = args['info_list']['notion_info_list']
+        elif args["info_list"]["data_source_type"] == "notion_import":
+            notion_info_list = args["info_list"]["notion_info_list"]
             for notion_info in notion_info_list:
-                workspace_id = notion_info['workspace_id']
-                for page in notion_info['pages']:
+                workspace_id = notion_info["workspace_id"]
+                for page in notion_info["pages"]:
                     extract_setting = ExtractSetting(
                         datasource_type="notion_import",
                         notion_info={
                             "notion_workspace_id": workspace_id,
-                            "notion_obj_id": page['page_id'],
-                            "notion_page_type": page['type'],
-                            "tenant_id": current_user.current_tenant_id
+                            "notion_obj_id": page["page_id"],
+                            "notion_page_type": page["type"],
+                            "tenant_id": current_user.current_tenant_id,
                         },
-                        document_model=args['doc_form']
+                        document_model=args["doc_form"],
                     )
                     extract_settings.append(extract_setting)
-        elif args['info_list']['data_source_type'] == 'website_crawl':
-            website_info_list = args['info_list']['website_info_list']
-            for url in website_info_list['urls']:
+        elif args["info_list"]["data_source_type"] == "website_crawl":
+            website_info_list = args["info_list"]["website_info_list"]
+            for url in website_info_list["urls"]:
                 extract_setting = ExtractSetting(
                     datasource_type="website_crawl",
                     website_info={
-                        "provider": website_info_list['provider'],
-                        "job_id": website_info_list['job_id'],
+                        "provider": website_info_list["provider"],
+                        "job_id": website_info_list["job_id"],
                         "url": url,
                         "tenant_id": current_user.current_tenant_id,
-                        "mode": 'crawl',
-                        "only_main_content": website_info_list['only_main_content']
+                        "mode": "crawl",
+                        "only_main_content": website_info_list["only_main_content"],
                     },
-                    document_model=args['doc_form']
+                    document_model=args["doc_form"],
                 )
                 extract_settings.append(extract_setting)
         else:
-            raise ValueError('Data source type not support')
+            raise ValueError("Data source type not support")
         indexing_runner = IndexingRunner()
         try:
-            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
-                                                         args['process_rule'], args['doc_form'],
-                                                         args['doc_language'], args['dataset_id'],
-                                                         args['indexing_technique'])
+            response = indexing_runner.indexing_estimate(
+                current_user.current_tenant_id,
+                extract_settings,
+                args["process_rule"],
+                args["doc_form"],
+                args["doc_language"],
+                args["dataset_id"],
+                args["indexing_technique"],
+            )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider.")
+                "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
+            )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         except Exception as e:
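To make the three data_source_type branches above concrete, here are hypothetical info_list values for each branch; all identifiers and URLs are placeholders, and the key names follow the dictionary accesses in the code:

# Hypothetical info_list payloads for the three branches above.
upload_file_info = {
    "data_source_type": "upload_file",
    "file_info_list": {"file_ids": ["<upload-file-uuid>"]},
}
notion_info = {
    "data_source_type": "notion_import",
    "notion_info_list": [
        {"workspace_id": "<workspace-uuid>", "pages": [{"page_id": "<page-uuid>", "type": "page"}]}
    ],
}
website_info = {
    "data_source_type": "website_crawl",
    "website_info_list": {
        "provider": "<crawl-provider>",
        "job_id": "<crawl-job-id>",
        "urls": ["https://example.com/docs"],
        "only_main_content": True,
    },
}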
@@ -402,7 +410,6 @@ class DatasetIndexingEstimateApi(Resource):


 class DatasetRelatedAppListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required

@@ -426,52 +433,52 @@ class DatasetRelatedAppListApi(Resource):
             if app_model:
                 related_apps.append(app_model)

-        return {
-            'data': related_apps,
-            'total': len(related_apps)
-        }, 200
+        return {"data": related_apps, "total": len(related_apps)}, 200


 class DatasetIndexingStatusApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     def get(self, dataset_id):
         dataset_id = str(dataset_id)
-        documents = db.session.query(Document).filter(
-            Document.dataset_id == dataset_id,
-            Document.tenant_id == current_user.current_tenant_id
-        ).all()
+        documents = (
+            db.session.query(Document)
+            .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
+            .all()
+        )
         documents_status = []
         for document in documents:
-            completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
-                                                              DocumentSegment.document_id == str(document.id),
-                                                              DocumentSegment.status != 're_segment').count()
-            total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
-                                                          DocumentSegment.status != 're_segment').count()
+            completed_segments = DocumentSegment.query.filter(
+                DocumentSegment.completed_at.isnot(None),
+                DocumentSegment.document_id == str(document.id),
+                DocumentSegment.status != "re_segment",
+            ).count()
+            total_segments = DocumentSegment.query.filter(
+                DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
+            ).count()
             document.completed_segments = completed_segments
             document.total_segments = total_segments
             documents_status.append(marshal(document, document_status_fields))
-        data = {
-            'data': documents_status
-        }
+        data = {"data": documents_status}
         return data
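A recurring pattern in this commit is wrapping long SQLAlchemy query chains in parentheses so each method call sits on its own line, instead of backslash continuations. A self-contained sketch of the idiom on a toy model (the model, engine, and filter are illustrative only):

# Toy model to show the parenthesized query-chain style used above.
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Doc(Base):
    __tablename__ = "docs"
    id = Column(Integer, primary_key=True)
    tenant_id = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    docs = (
        session.query(Doc)
        .filter(Doc.tenant_id == "t1")
        .all()
    )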
 class DatasetApiKeyApi(Resource):
     max_keys = 10
-    token_prefix = 'dataset-'
-    resource_type = 'dataset'
+    token_prefix = "dataset-"
+    resource_type = "dataset"

     @setup_required
     @login_required
     @account_initialization_required
     @marshal_with(api_key_list)
     def get(self):
-        keys = db.session.query(ApiToken). \
-            filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
-            all()
+        keys = (
+            db.session.query(ApiToken)
+            .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
+            .all()
+        )
         return {"items": keys}

     @setup_required

@@ -483,15 +490,17 @@ class DatasetApiKeyApi(Resource):
         if not current_user.is_admin_or_owner:
             raise Forbidden()

-        current_key_count = db.session.query(ApiToken). \
-            filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
-            count()
+        current_key_count = (
+            db.session.query(ApiToken)
+            .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
+            .count()
+        )

         if current_key_count >= self.max_keys:
             flask_restful.abort(
                 400,
                 message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
-                code='max_keys_exceeded'
+                code="max_keys_exceeded",
             )

         key = ApiToken.generate_api_key(self.token_prefix, 24)
@@ -505,7 +514,7 @@ class DatasetApiKeyApi(Resource):


 class DatasetApiDeleteApi(Resource):
-    resource_type = 'dataset'
+    resource_type = "dataset"

     @setup_required
     @login_required

@@ -517,18 +526,23 @@ class DatasetApiDeleteApi(Resource):
         if not current_user.is_admin_or_owner:
             raise Forbidden()

-        key = db.session.query(ApiToken). \
-            filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
-                   ApiToken.id == api_key_id). \
-            first()
+        key = (
+            db.session.query(ApiToken)
+            .filter(
+                ApiToken.tenant_id == current_user.current_tenant_id,
+                ApiToken.type == self.resource_type,
+                ApiToken.id == api_key_id,
+            )
+            .first()
+        )

         if key is None:
-            flask_restful.abort(404, message='API key not found')
+            flask_restful.abort(404, message="API key not found")

         db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
         db.session.commit()

-        return {'result': 'success'}, 204
+        return {"result": "success"}, 204


 class DatasetApiBaseUrlApi(Resource):

@@ -537,8 +551,10 @@ class DatasetApiBaseUrlApi(Resource):
     @account_initialization_required
     def get(self):
         return {
-            'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL
-                             else request.host_url.rstrip('/')) + '/v1'
+            "api_base_url": (
+                dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
+            )
+            + "/v1"
         }
@@ -549,15 +565,26 @@ class DatasetRetrievalSettingApi(Resource):
     def get(self):
         vector_type = dify_config.VECTOR_STORE
         match vector_type:
-            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
-                return {
-                    'retrieval_method': [
-                        RetrievalMethod.SEMANTIC_SEARCH.value
-                    ]
-                }
-            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
-                return {
-                    'retrieval_method': [
+            case (
+                VectorType.MILVUS
+                | VectorType.RELYT
+                | VectorType.PGVECTOR
+                | VectorType.TIDB_VECTOR
+                | VectorType.CHROMA
+                | VectorType.TENCENT
+            ):
+                return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
+            case (
+                VectorType.QDRANT
+                | VectorType.WEAVIATE
+                | VectorType.OPENSEARCH
+                | VectorType.ANALYTICDB
+                | VectorType.MYSCALE
+                | VectorType.ORACLE
+                | VectorType.ELASTICSEARCH
+            ):
+                return {
+                    "retrieval_method": [
                         RetrievalMethod.SEMANTIC_SEARCH.value,
                         RetrievalMethod.FULL_TEXT_SEARCH.value,
                         RetrievalMethod.HYBRID_SEARCH.value,

@@ -573,15 +600,27 @@ class DatasetRetrievalSettingMockApi(Resource):
     @account_initialization_required
     def get(self, vector_type):
         match vector_type:
-            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
-                return {
-                    'retrieval_method': [
-                        RetrievalMethod.SEMANTIC_SEARCH.value
-                    ]
-                }
-            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
-                return {
-                    'retrieval_method': [
+            case (
+                VectorType.MILVUS
+                | VectorType.RELYT
+                | VectorType.TIDB_VECTOR
+                | VectorType.CHROMA
+                | VectorType.TENCENT
+                | VectorType.PGVECTO_RS
+            ):
+                return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
+            case (
+                VectorType.QDRANT
+                | VectorType.WEAVIATE
+                | VectorType.OPENSEARCH
+                | VectorType.ANALYTICDB
+                | VectorType.MYSCALE
+                | VectorType.ORACLE
+                | VectorType.ELASTICSEARCH
+                | VectorType.PGVECTOR
+            ):
+                return {
+                    "retrieval_method": [
                         RetrievalMethod.SEMANTIC_SEARCH.value,
                         RetrievalMethod.FULL_TEXT_SEARCH.value,
                         RetrievalMethod.HYBRID_SEARCH.value,
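The hunks above use structural pattern matching (Python 3.10+), where `|` joins alternative value patterns inside a single case arm. A self-contained sketch of the same shape, with a toy enum standing in for VectorType:

# Or-patterns in match/case: one arm matches any of several enum members.
from enum import Enum

class Store(Enum):
    MILVUS = "milvus"
    QDRANT = "qdrant"
    WEAVIATE = "weaviate"

def methods_for(store: Store) -> list[str]:
    match store:
        case Store.MILVUS:
            return ["semantic_search"]
        case Store.QDRANT | Store.WEAVIATE:
            return ["semantic_search", "full_text_search", "hybrid_search"]
    raise ValueError(f"Unsupported vector db type {store}.")

assert methods_for(Store.QDRANT) == ["semantic_search", "full_text_search", "hybrid_search"]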
@@ -591,7 +630,6 @@ class DatasetRetrievalSettingMockApi(Resource):
             raise ValueError(f"Unsupported vector db type {vector_type}.")


 class DatasetErrorDocs(Resource):
     @setup_required
     @login_required

@@ -603,10 +641,7 @@ class DatasetErrorDocs(Resource):
             raise NotFound("Dataset not found.")
         results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)

-        return {
-            'data': [marshal(item, document_status_fields) for item in results],
-            'total': len(results)
-        }, 200
+        return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200


 class DatasetPermissionUserListApi(Resource):

@@ -626,21 +661,21 @@ class DatasetPermissionUserListApi(Resource):
         partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)

         return {
-            'data': partial_members_list,
+            "data": partial_members_list,
         }, 200


-api.add_resource(DatasetListApi, '/datasets')
-api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
-api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
-api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
-api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
-api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
-api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
-api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
-api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
-api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
-api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
-api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
-api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
-api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')
+api.add_resource(DatasetListApi, "/datasets")
+api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
+api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
+api.add_resource(DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries")
+api.add_resource(DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs")
+api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate")
+api.add_resource(DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps")
+api.add_resource(DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status")
+api.add_resource(DatasetApiKeyApi, "/datasets/api-keys")
+api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>")
+api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
+api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
+api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
+api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
File diff suppressed because it is too large
@@ -40,7 +40,7 @@ class DatasetDocumentSegmentListApi(Resource):
         document_id = str(document_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")

         try:
             DatasetService.check_dataset_permission(dataset, current_user)

@@ -50,37 +50,33 @@ class DatasetDocumentSegmentListApi(Resource):
         document = DocumentService.get_document(dataset_id, document_id)

         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")

         parser = reqparse.RequestParser()
-        parser.add_argument('last_id', type=str, default=None, location='args')
-        parser.add_argument('limit', type=int, default=20, location='args')
-        parser.add_argument('status', type=str,
-                            action='append', default=[], location='args')
-        parser.add_argument('hit_count_gte', type=int,
-                            default=None, location='args')
-        parser.add_argument('enabled', type=str, default='all', location='args')
-        parser.add_argument('keyword', type=str, default=None, location='args')
+        parser.add_argument("last_id", type=str, default=None, location="args")
+        parser.add_argument("limit", type=int, default=20, location="args")
+        parser.add_argument("status", type=str, action="append", default=[], location="args")
+        parser.add_argument("hit_count_gte", type=int, default=None, location="args")
+        parser.add_argument("enabled", type=str, default="all", location="args")
+        parser.add_argument("keyword", type=str, default=None, location="args")
         args = parser.parse_args()

-        last_id = args['last_id']
-        limit = min(args['limit'], 100)
-        status_list = args['status']
-        hit_count_gte = args['hit_count_gte']
-        keyword = args['keyword']
+        last_id = args["last_id"]
+        limit = min(args["limit"], 100)
+        status_list = args["status"]
+        hit_count_gte = args["hit_count_gte"]
+        keyword = args["keyword"]

         query = DocumentSegment.query.filter(
-            DocumentSegment.document_id == str(document_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         )

         if last_id is not None:
             last_segment = db.session.get(DocumentSegment, str(last_id))
             if last_segment:
-                query = query.filter(
-                    DocumentSegment.position > last_segment.position)
+                query = query.filter(DocumentSegment.position > last_segment.position)
             else:
-                return {'data': [], 'has_more': False, 'limit': limit}, 200
+                return {"data": [], "has_more": False, "limit": limit}, 200

         if status_list:
             query = query.filter(DocumentSegment.status.in_(status_list))
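The last_id handling above is keyset (cursor) pagination: instead of an offset, the client passes the last segment it saw, and the query resumes after that segment's position. A toy sketch of the same idea over a plain list, with illustrative names:

# Toy keyset pagination mirroring the last_id / position logic above,
# without the database.
segments = [{"id": f"s{i}", "position": i} for i in range(1, 8)]

def page_after(last_id, limit):
    if last_id is None:
        return segments[:limit]
    anchor = next((s for s in segments if s["id"] == last_id), None)
    if anchor is None:
        return []  # unknown cursor: empty page, like the view's early return
    after = [s for s in segments if s["position"] > anchor["position"]]
    return after[:limit]

assert [s["id"] for s in page_after("s2", 3)] == ["s3", "s4", "s5"]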
@@ -89,12 +85,12 @@ class DatasetDocumentSegmentListApi(Resource):
             query = query.filter(DocumentSegment.hit_count >= hit_count_gte)

         if keyword:
-            query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
+            query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))

-        if args['enabled'].lower() != 'all':
-            if args['enabled'].lower() == 'true':
+        if args["enabled"].lower() != "all":
+            if args["enabled"].lower() == "true":
                 query = query.filter(DocumentSegment.enabled == True)
-            elif args['enabled'].lower() == 'false':
+            elif args["enabled"].lower() == "false":
                 query = query.filter(DocumentSegment.enabled == False)

         total = query.count()

@@ -106,11 +102,11 @@ class DatasetDocumentSegmentListApi(Resource):
             segments = segments[:-1]

         return {
-            'data': marshal(segments, segment_fields),
-            'doc_form': document.doc_form,
-            'has_more': has_more,
-            'limit': limit,
-            'total': total
+            "data": marshal(segments, segment_fields),
+            "doc_form": document.doc_form,
+            "has_more": has_more,
+            "limit": limit,
+            "total": total,
         }, 200
@@ -118,12 +114,12 @@ class DatasetDocumentSegmentApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('vector_space')
+    @cloud_edition_billing_resource_check("vector_space")
     def patch(self, dataset_id, segment_id, action):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check user's model setting
         DatasetService.check_dataset_model_setting(dataset)
         # The role of the current user in the ta table must be admin, owner, or editor

@@ -134,7 +130,7 @@ class DatasetDocumentSegmentApi(Resource):
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
-        if dataset.indexing_technique == 'high_quality':
+        if dataset.indexing_technique == "high_quality":
             # check embedding model setting
             try:
                 model_manager = ModelManager()

@@ -142,32 +138,32 @@ class DatasetDocumentSegmentApi(Resource):
                     tenant_id=current_user.current_tenant_id,
                     provider=dataset.embedding_model_provider,
                     model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model
+                    model=dataset.embedding_model,
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
+                    "in the Settings -> Model Provider."
+                )
             except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)

         segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         ).first()

         if not segment:
-            raise NotFound('Segment not found.')
+            raise NotFound("Segment not found.")

-        if segment.status != 'completed':
-            raise NotFound('Segment is not completed, enable or disable function is not allowed')
+        if segment.status != "completed":
+            raise NotFound("Segment is not completed, enable or disable function is not allowed")

-        document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
+        document_indexing_cache_key = "document_{}_indexing".format(segment.document_id)
         cache_result = redis_client.get(document_indexing_cache_key)
         if cache_result is not None:
             raise InvalidActionError("Document is being indexed, please try again later")

-        indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
+        indexing_cache_key = "segment_{}_indexing".format(segment.id)
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is not None:
             raise InvalidActionError("Segment is being indexed, please try again later")

@@ -186,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource):

             enable_segment_to_index_task.delay(segment.id)

-            return {'result': 'success'}, 200
+            return {"result": "success"}, 200
         elif action == "disable":
             if not segment.enabled:
                 raise InvalidActionError("Segment is already disabled.")

@@ -201,7 +197,7 @@ class DatasetDocumentSegmentApi(Resource):

             disable_segment_from_index_task.delay(segment.id)

-            return {'result': 'success'}, 200
+            return {"result": "success"}, 200
         else:
             raise InvalidActionError()
@@ -210,35 +206,36 @@ class DatasetDocumentSegmentAddApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('vector_space')
-    @cloud_edition_billing_knowledge_limit_check('add_segment')
+    @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_knowledge_limit_check("add_segment")
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")
         if not current_user.is_editor:
             raise Forbidden()
         # check embedding model setting
-        if dataset.indexing_technique == 'high_quality':
+        if dataset.indexing_technique == "high_quality":
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
                     provider=dataset.embedding_model_provider,
                     model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model
+                    model=dataset.embedding_model,
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
+                    "in the Settings -> Model Provider."
+                )
             except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)
         try:
@@ -247,37 +244,34 @@ class DatasetDocumentSegmentAddApi(Resource):
             raise Forbidden(str(e))
         # validate args
         parser = reqparse.RequestParser()
-        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
-        parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
+        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
         args = parser.parse_args()
         SegmentService.segment_create_args_validate(args, document)
         segment = SegmentService.create_segment(args, document, dataset)
-        return {
-            'data': marshal(segment, segment_fields),
-            'doc_form': document.doc_form
-        }, 200
+        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200


 class DatasetDocumentSegmentUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('vector_space')
+    @cloud_edition_billing_resource_check("vector_space")
     def patch(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check user's model setting
         DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
-        if dataset.indexing_technique == 'high_quality':
+            raise NotFound("Document not found.")
+        if dataset.indexing_technique == "high_quality":
             # check embedding model setting
             try:
                 model_manager = ModelManager()

@@ -285,22 +279,22 @@ class DatasetDocumentSegmentUpdateApi(Resource):
                     tenant_id=current_user.current_tenant_id,
                     provider=dataset.embedding_model_provider,
                     model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model
+                    model=dataset.embedding_model,
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
+                    "in the Settings -> Model Provider."
+                )
             except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)
         # check segment
         segment_id = str(segment_id)
         segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         ).first()
         if not segment:
-            raise NotFound('Segment not found.')
+            raise NotFound("Segment not found.")
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()

@@ -310,16 +304,13 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             raise Forbidden(str(e))
         # validate args
         parser = reqparse.RequestParser()
-        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
-        parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
+        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
         args = parser.parse_args()
         SegmentService.segment_create_args_validate(args, document)
         segment = SegmentService.update_segment(args, segment, document, dataset)
-        return {
-            'data': marshal(segment, segment_fields),
-            'doc_form': document.doc_form
-        }, 200
+        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200

     @setup_required
     @login_required
@@ -329,22 +320,21 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check user's model setting
         DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
         segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         ).first()
         if not segment:
-            raise NotFound('Segment not found.')
+            raise NotFound("Segment not found.")
         # The role of the current user in the ta table must be admin or owner
         if not current_user.is_editor:
             raise Forbidden()

@@ -353,36 +343,36 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         SegmentService.delete_segment(segment, document, dataset)
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200


 class DatasetDocumentSegmentBatchImportApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('vector_space')
-    @cloud_edition_billing_knowledge_limit_check('add_segment')
+    @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_knowledge_limit_check("add_segment")
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")
         # get file from request
-        file = request.files['file']
+        file = request.files["file"]
         # check file
-        if 'file' not in request.files:
+        if "file" not in request.files:
             raise NoFileUploadedError()

         if len(request.files) > 1:
             raise TooManyFilesError()
         # check file type
-        if not file.filename.endswith('.csv'):
+        if not file.filename.endswith(".csv"):
             raise ValueError("Invalid file type. Only CSV files are allowed")

         try:
@@ -390,51 +380,47 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
             df = pd.read_csv(file)
             result = []
             for index, row in df.iterrows():
-                if document.doc_form == 'qa_model':
-                    data = {'content': row[0], 'answer': row[1]}
+                if document.doc_form == "qa_model":
+                    data = {"content": row[0], "answer": row[1]}
                 else:
-                    data = {'content': row[0]}
+                    data = {"content": row[0]}
                 result.append(data)
             if len(result) == 0:
                 raise ValueError("The CSV file is empty.")
             # async job
             job_id = str(uuid.uuid4())
-            indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
+            indexing_cache_key = "segment_batch_import_{}".format(str(job_id))
             # send batch add segments task
-            redis_client.setnx(indexing_cache_key, 'waiting')
-            batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
-                                                     current_user.current_tenant_id, current_user.id)
+            redis_client.setnx(indexing_cache_key, "waiting")
+            batch_create_segment_to_index_task.delay(
+                str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id
+            )
         except Exception as e:
-            return {'error': str(e)}, 500
-        return {
-            'job_id': job_id,
-            'job_status': 'waiting'
-        }, 200
+            return {"error": str(e)}, 500
+        return {"job_id": job_id, "job_status": "waiting"}, 200

     @setup_required
     @login_required
     @account_initialization_required
     def get(self, job_id):
         job_id = str(job_id)
-        indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
+        indexing_cache_key = "segment_batch_import_{}".format(job_id)
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is None:
             raise ValueError("The job is not exist.")

-        return {
-            'job_id': job_id,
-            'job_status': cache_result.decode()
-        }, 200
+        return {"job_id": job_id, "job_status": cache_result.decode()}, 200
api.add_resource(DatasetDocumentSegmentListApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
|
||||
api.add_resource(DatasetDocumentSegmentApi,
|
||||
'/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
|
||||
api.add_resource(DatasetDocumentSegmentAddApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
|
||||
api.add_resource(DatasetDocumentSegmentUpdateApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
|
||||
api.add_resource(DatasetDocumentSegmentBatchImportApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
|
||||
'/datasets/batch_import_status/<uuid:job_id>')
|
||||
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>")
|
||||
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
|
||||
api.add_resource(
|
||||
DatasetDocumentSegmentUpdateApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasetDocumentSegmentBatchImportApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
|
||||
"/datasets/batch_import_status/<uuid:job_id>",
|
||||
)
|
||||
|
|
|
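Note: the CSV-to-segment mapping reformatted above is easy to test in isolation. A minimal sketch (pandas only; the two-column qa_model layout mirrors the handler, the sample data and function name are made up):

import io

import pandas as pd


def rows_to_segments(csv_bytes: bytes, doc_form: str) -> list:
    # read_csv consumes the first row as a header, as in the handler
    df = pd.read_csv(io.BytesIO(csv_bytes))
    result = []
    for _, row in df.iterrows():
        if doc_form == "qa_model":
            data = {"content": row[0], "answer": row[1]}  # positional access, as in the handler
        else:
            data = {"content": row[0]}
        result.append(data)
    if len(result) == 0:
        raise ValueError("The CSV file is empty.")
    return result


print(rows_to_segments(b"content,answer\nWhat is Dify?,An LLMOps platform\n", "qa_model"))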
@@ -2,90 +2,90 @@ from libs.exception import BaseHTTPException


 class NoFileUploadedError(BaseHTTPException):
-    error_code = 'no_file_uploaded'
+    error_code = "no_file_uploaded"
     description = "Please upload your file."
     code = 400


 class TooManyFilesError(BaseHTTPException):
-    error_code = 'too_many_files'
+    error_code = "too_many_files"
     description = "Only one file is allowed."
     code = 400


 class FileTooLargeError(BaseHTTPException):
-    error_code = 'file_too_large'
+    error_code = "file_too_large"
     description = "File size exceeded. {message}"
     code = 413


 class UnsupportedFileTypeError(BaseHTTPException):
-    error_code = 'unsupported_file_type'
+    error_code = "unsupported_file_type"
     description = "File type not allowed."
     code = 415


 class HighQualityDatasetOnlyError(BaseHTTPException):
-    error_code = 'high_quality_dataset_only'
+    error_code = "high_quality_dataset_only"
     description = "Current operation only supports 'high-quality' datasets."
     code = 400


 class DatasetNotInitializedError(BaseHTTPException):
-    error_code = 'dataset_not_initialized'
+    error_code = "dataset_not_initialized"
     description = "The dataset is still being initialized or indexing. Please wait a moment."
     code = 400


 class ArchivedDocumentImmutableError(BaseHTTPException):
-    error_code = 'archived_document_immutable'
+    error_code = "archived_document_immutable"
     description = "The archived document is not editable."
     code = 403


 class DatasetNameDuplicateError(BaseHTTPException):
-    error_code = 'dataset_name_duplicate'
+    error_code = "dataset_name_duplicate"
     description = "The dataset name already exists. Please modify your dataset name."
     code = 409


 class InvalidActionError(BaseHTTPException):
-    error_code = 'invalid_action'
+    error_code = "invalid_action"
     description = "Invalid action."
     code = 400


 class DocumentAlreadyFinishedError(BaseHTTPException):
-    error_code = 'document_already_finished'
+    error_code = "document_already_finished"
     description = "The document has been processed. Please refresh the page or go to the document details."
     code = 400


 class DocumentIndexingError(BaseHTTPException):
-    error_code = 'document_indexing'
+    error_code = "document_indexing"
     description = "The document is being processed and cannot be edited."
     code = 400


 class InvalidMetadataError(BaseHTTPException):
-    error_code = 'invalid_metadata'
+    error_code = "invalid_metadata"
     description = "The metadata content is incorrect. Please check and verify."
     code = 400


 class WebsiteCrawlError(BaseHTTPException):
-    error_code = 'crawl_failed'
+    error_code = "crawl_failed"
     description = "{message}"
     code = 500


 class DatasetInUseError(BaseHTTPException):
-    error_code = 'dataset_in_use'
+    error_code = "dataset_in_use"
     description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
     code = 409


 class IndexingEstimateError(BaseHTTPException):
-    error_code = 'indexing_estimate_error'
+    error_code = "indexing_estimate_error"
     description = "Knowledge indexing estimate failed: {message}"
     code = 500
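Note: every class in this file follows the same three-attribute pattern. A self-contained sketch (BaseHTTPException lives in libs.exception in the repo; the stand-in below is illustrative only):

class BaseHTTPException(Exception):
    # illustrative stand-in for libs.exception.BaseHTTPException
    error_code: str = "unknown"
    description: str = ""
    code: int = 500


class DatasetInUseError(BaseHTTPException):
    error_code = "dataset_in_use"
    description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
    code = 409


try:
    raise DatasetInUseError()
except BaseHTTPException as e:
    print(e.code, e.error_code, e.description)  # 409 dataset_in_use ...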
@@ -21,7 +21,6 @@ PREVIEW_WORDS_LIMIT = 3000


 class FileApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required

@@ -31,23 +30,22 @@ class FileApi(Resource):
         batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
         image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
         return {
-            'file_size_limit': file_size_limit,
-            'batch_count_limit': batch_count_limit,
-            'image_file_size_limit': image_file_size_limit
+            "file_size_limit": file_size_limit,
+            "batch_count_limit": batch_count_limit,
+            "image_file_size_limit": image_file_size_limit,
         }, 200

     @setup_required
     @login_required
     @account_initialization_required
     @marshal_with(file_fields)
-    @cloud_edition_billing_resource_check(resource='documents')
+    @cloud_edition_billing_resource_check("documents")
     def post(self):
-
         # get file from request
-        file = request.files['file']
+        file = request.files["file"]

         # check file
-        if 'file' not in request.files:
+        if "file" not in request.files:
             raise NoFileUploadedError()

         if len(request.files) > 1:

@@ -69,7 +67,7 @@ class FilePreviewApi(Resource):
     def get(self, file_id):
         file_id = str(file_id)
         text = FileService.get_file_preview(file_id)
-        return {'content': text}
+        return {"content": text}


 class FileSupportTypeApi(Resource):

@@ -78,10 +76,10 @@ class FileSupportTypeApi(Resource):
     @account_initialization_required
     def get(self):
         etl_type = dify_config.ETL_TYPE
-        allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
-        return {'allowed_extensions': allowed_extensions}
+        allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
+        return {"allowed_extensions": allowed_extensions}


-api.add_resource(FileApi, '/files/upload')
-api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')
-api.add_resource(FileSupportTypeApi, '/files/support-type')
+api.add_resource(FileApi, "/files/upload")
+api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
+api.add_resource(FileSupportTypeApi, "/files/support-type")
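Note: a hypothetical client call against the limits endpoint above; the base URL, route prefix, and auth header are placeholder assumptions, while the response keys come from the hunk:

import requests

resp = requests.get(
    "http://localhost:5001/console/api/files/upload",  # placeholder base URL and prefix
    headers={"Authorization": "Bearer <console-token>"},  # placeholder token
)
limits = resp.json()
print(limits["file_size_limit"], limits["batch_count_limit"], limits["image_file_size_limit"])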
@@ -29,7 +29,6 @@ from services.hit_testing_service import HitTestingService


 class HitTestingApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required

@@ -46,8 +45,8 @@ class HitTestingApi(Resource):
             raise Forbidden(str(e))

         parser = reqparse.RequestParser()
-        parser.add_argument('query', type=str, location='json')
-        parser.add_argument('retrieval_model', type=dict, required=False, location='json')
+        parser.add_argument("query", type=str, location="json")
+        parser.add_argument("retrieval_model", type=dict, required=False, location="json")
         args = parser.parse_args()

         HitTestingService.hit_testing_args_check(args)

@@ -55,13 +54,13 @@ class HitTestingApi(Resource):
         try:
             response = HitTestingService.retrieve(
                 dataset=dataset,
-                query=args['query'],
+                query=args["query"],
                 account=current_user,
-                retrieval_model=args['retrieval_model'],
-                limit=10
+                retrieval_model=args["retrieval_model"],
+                limit=10,
             )

-            return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
+            return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
         except services.errors.index.IndexNotInitializedError:
             raise DatasetNotInitializedError()
         except ProviderTokenNotInitError as ex:

@@ -73,7 +72,8 @@ class HitTestingApi(Resource):
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 "No Embedding Model or Reranking Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider.")
+                "in the Settings -> Model Provider."
+            )
         except InvokeError as e:
             raise CompletionRequestError(e.description)
         except ValueError as e:

@@ -83,4 +83,4 @@ class HitTestingApi(Resource):
             raise InternalServerError(str(e))


-api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')
+api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
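Note: the reqparse pattern above is runnable outside Dify. A minimal sketch with a made-up Flask app and payload:

from flask import Flask
from flask_restful import reqparse

app = Flask(__name__)

# simulate an incoming JSON request so parse_args() has something to read
with app.test_request_context(json={"query": "hello", "retrieval_model": {"top_k": 2}}):
    parser = reqparse.RequestParser()
    parser.add_argument("query", type=str, location="json")
    parser.add_argument("retrieval_model", type=dict, required=False, location="json")
    args = parser.parse_args()
    print(args["query"], args["retrieval_model"])  # hello {'top_k': 2}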
@@ -9,16 +9,14 @@ from services.website_service import WebsiteService


 class WebsiteCrawlApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('provider', type=str, choices=['firecrawl'],
-                            required=True, nullable=True, location='json')
-        parser.add_argument('url', type=str, required=True, nullable=True, location='json')
-        parser.add_argument('options', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json")
+        parser.add_argument("url", type=str, required=True, nullable=True, location="json")
+        parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
         args = parser.parse_args()
         WebsiteService.document_create_args_validate(args)
         # crawl url

@@ -35,15 +33,15 @@ class WebsiteCrawlStatusApi(Resource):
     @account_initialization_required
     def get(self, job_id: str):
         parser = reqparse.RequestParser()
-        parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args')
+        parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args")
         args = parser.parse_args()
         # get crawl status
         try:
-            result = WebsiteService.get_crawl_status(job_id, args['provider'])
+            result = WebsiteService.get_crawl_status(job_id, args["provider"])
         except Exception as e:
             raise WebsiteCrawlError(str(e))
         return result, 200


-api.add_resource(WebsiteCrawlApi, '/website/crawl')
-api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>')
+api.add_resource(WebsiteCrawlApi, "/website/crawl")
+api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/<string:job_id>")
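Note: a hypothetical client-side call shaped by the parser above (provider is restricted to "firecrawl" by the choices argument); the base URL, token, and options keys are illustrative assumptions:

import requests

payload = {
    "provider": "firecrawl",        # only accepted choice per the parser
    "url": "https://example.com",
    "options": {"limit": 1},        # options keys are assumptions, not from the hunk
}
resp = requests.post(
    "http://localhost:5001/console/api/website/crawl",  # placeholder base URL and prefix
    json=payload,
    headers={"Authorization": "Bearer <console-token>"},  # placeholder token
)
print(resp.status_code, resp.json())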
@@ -2,35 +2,41 @@ from libs.exception import BaseHTTPException


 class AlreadySetupError(BaseHTTPException):
-    error_code = 'already_setup'
+    error_code = "already_setup"
     description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage."
     code = 403


 class NotSetupError(BaseHTTPException):
-    error_code = 'not_setup'
-    description = "Dify has not been initialized and installed yet. " \
-                  "Please proceed with the initialization and installation process first."
+    error_code = "not_setup"
+    description = (
+        "Dify has not been initialized and installed yet. "
+        "Please proceed with the initialization and installation process first."
+    )
     code = 401


 class NotInitValidateError(BaseHTTPException):
-    error_code = 'not_init_validated'
-    description = "Init validation has not been completed yet. " \
-                  "Please proceed with the init validation process first."
+    error_code = "not_init_validated"
+    description = (
+        "Init validation has not been completed yet. " "Please proceed with the init validation process first."
+    )
     code = 401


 class InitValidateFailedError(BaseHTTPException):
-    error_code = 'init_validate_failed'
+    error_code = "init_validate_failed"
     description = "Init validation failed. Please check the password and try again."
     code = 401


 class AccountNotLinkTenantError(BaseHTTPException):
-    error_code = 'account_not_link_tenant'
+    error_code = "account_not_link_tenant"
     description = "Account not link tenant."
     code = 403


 class AlreadyActivateError(BaseHTTPException):
-    error_code = 'already_activate'
+    error_code = "already_activate"
     description = "Auth Token is invalid or account already activated, please check again."
     code = 403
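Note: the parenthesized descriptions above rely on implicit concatenation of adjacent string literals, which is why ruff can leave two literals on one line without changing the value:

description = (
    "Init validation has not been completed yet. " "Please proceed with the init validation process first."
)
assert description == (
    "Init validation has not been completed yet. Please proceed with the init validation process first."
)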
@@ -33,14 +33,10 @@ class ChatAudioApi(InstalledAppResource):
     def post(self, installed_app):
         app_model = installed_app.app

-        file = request.files['file']
+        file = request.files["file"]

         try:
-            response = AudioService.transcript_asr(
-                app_model=app_model,
-                file=file,
-                end_user=None
-            )
+            response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)

             return response
         except services.errors.app_model_config.AppModelConfigBrokenError:

@@ -76,30 +72,31 @@ class ChatTextApi(InstalledAppResource):
         app_model = installed_app.app
         try:
             parser = reqparse.RequestParser()
-            parser.add_argument('message_id', type=str, required=False, location='json')
-            parser.add_argument('voice', type=str, location='json')
-            parser.add_argument('text', type=str, location='json')
-            parser.add_argument('streaming', type=bool, location='json')
+            parser.add_argument("message_id", type=str, required=False, location="json")
+            parser.add_argument("voice", type=str, location="json")
+            parser.add_argument("text", type=str, location="json")
+            parser.add_argument("streaming", type=bool, location="json")
             args = parser.parse_args()

-            message_id = args.get('message_id', None)
-            text = args.get('text', None)
-            if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
-                    and app_model.workflow
-                    and app_model.workflow.features_dict):
-                text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
-                voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+            message_id = args.get("message_id", None)
+            text = args.get("text", None)
+            if (
+                app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                and app_model.workflow
+                and app_model.workflow.features_dict
+            ):
+                text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
+                voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
             else:
                 try:
-                    voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
+                    voice = (
+                        args.get("voice")
+                        if args.get("voice")
+                        else app_model.app_model_config.text_to_speech_dict.get("voice")
+                    )
                 except Exception:
                     voice = None
-            response = AudioService.transcript_tts(
-                app_model=app_model,
-                message_id=message_id,
-                voice=voice,
-                text=text
-            )
+            response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
             return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")

@@ -127,7 +124,7 @@ class ChatTextApi(InstalledAppResource):
         raise InternalServerError()


-api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
-api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
+api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
+api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
 # api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
 #                  endpoint='installed_app_text_with_message_id')
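Note: the voice-selection fallback reformatted above, distilled into a standalone helper; the function and the dict shape are assumptions based on the hunk:

def pick_voice(requested, features: dict):
    # prefer the explicitly requested voice, else the app's configured TTS voice
    tts = features.get("text_to_speech") or {}
    return requested if requested else tts.get("voice")


print(pick_voice(None, {"text_to_speech": {"voice": "alloy"}}))   # -> alloy
print(pick_voice("echo", {"text_to_speech": {"voice": "alloy"}}))  # -> echo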
@@ -30,33 +30,28 @@ from services.app_generate_service import AppGenerateService

 # define completion api for user
 class CompletionApi(InstalledAppResource):
-
     def post(self, installed_app):
         app_model = installed_app.app
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()

         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, location='json', default='')
-        parser.add_argument('files', type=list, required=False, location='json')
-        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, location="json", default="")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
         args = parser.parse_args()

-        streaming = args['response_mode'] == 'streaming'
-        args['auto_generate_name'] = False
+        streaming = args["response_mode"] == "streaming"
+        args["auto_generate_name"] = False

         installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
         db.session.commit()

         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=current_user,
-                args=args,
-                invoke_from=InvokeFrom.EXPLORE,
-                streaming=streaming
+                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
             )

             return helper.compact_generate_response(response)

@@ -85,12 +80,12 @@ class CompletionApi(InstalledAppResource):
 class CompletionStopApi(InstalledAppResource):
     def post(self, installed_app, task_id):
         app_model = installed_app.app
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()

         AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200


 class ChatApi(InstalledAppResource):

@@ -101,25 +96,21 @@ class ChatApi(InstalledAppResource):
             raise NotChatAppError()

         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, required=True, location='json')
-        parser.add_argument('files', type=list, required=False, location='json')
-        parser.add_argument('conversation_id', type=uuid_value, location='json')
-        parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, required=True, location="json")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("conversation_id", type=uuid_value, location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
         args = parser.parse_args()

-        args['auto_generate_name'] = False
+        args["auto_generate_name"] = False

         installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
         db.session.commit()

         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=current_user,
-                args=args,
-                invoke_from=InvokeFrom.EXPLORE,
-                streaming=True
+                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
             )

             return helper.compact_generate_response(response)

@@ -154,10 +145,22 @@ class ChatStopApi(InstalledAppResource):
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200


-api.add_resource(CompletionApi, '/installed-apps/<uuid:installed_app_id>/completion-messages', endpoint='installed_app_completion')
-api.add_resource(CompletionStopApi, '/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop', endpoint='installed_app_stop_completion')
-api.add_resource(ChatApi, '/installed-apps/<uuid:installed_app_id>/chat-messages', endpoint='installed_app_chat_completion')
-api.add_resource(ChatStopApi, '/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop', endpoint='installed_app_stop_chat_completion')
+api.add_resource(
+    CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
+)
+api.add_resource(
+    CompletionStopApi,
+    "/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
+    endpoint="installed_app_stop_completion",
+)
+api.add_resource(
+    ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
+)
+api.add_resource(
+    ChatStopApi,
+    "/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
+    endpoint="installed_app_stop_chat_completion",
+)
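Note: response_mode collapses to a boolean before the generate call. A hypothetical client sketch (placeholder URL, app id, and token; completion-messages is the endpoint whose parser accepts response_mode above):

import requests

args = {"response_mode": "streaming"}  # as parsed by the handler above
streaming = args["response_mode"] == "streaming"

resp = requests.post(
    "http://localhost:5001/console/api/installed-apps/<installed_app_id>/completion-messages",  # placeholder
    json={"inputs": {}, "query": "hi", "response_mode": args["response_mode"]},
    headers={"Authorization": "Bearer <console-token>"},  # placeholder
    stream=streaming,
)
if streaming:
    for line in resp.iter_lines():
        print(line.decode())  # server-sent event lines
else:
    print(resp.json())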
@@ -16,7 +16,6 @@ from services.web_conversation_service import WebConversationService


 class ConversationListApi(InstalledAppResource):
-
     @marshal_with(conversation_infinite_scroll_pagination_fields)
     def get(self, installed_app):
         app_model = installed_app.app

@@ -25,21 +24,21 @@ class ConversationListApi(InstalledAppResource):
             raise NotChatAppError()

         parser = reqparse.RequestParser()
-        parser.add_argument('last_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
+        parser.add_argument("last_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
+        parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
         args = parser.parse_args()

         pinned = None
-        if 'pinned' in args and args['pinned'] is not None:
-            pinned = True if args['pinned'] == 'true' else False
+        if "pinned" in args and args["pinned"] is not None:
+            pinned = True if args["pinned"] == "true" else False

         try:
             return WebConversationService.pagination_by_last_id(
                 app_model=app_model,
                 user=current_user,
-                last_id=args['last_id'],
-                limit=args['limit'],
+                last_id=args["last_id"],
+                limit=args["limit"],
                 invoke_from=InvokeFrom.EXPLORE,
                 pinned=pinned,
             )

@@ -65,7 +64,6 @@ class ConversationApi(InstalledAppResource):


 class ConversationRenameApi(InstalledAppResource):
-
     @marshal_with(simple_conversation_fields)
     def post(self, installed_app, c_id):
         app_model = installed_app.app

@@ -76,24 +74,19 @@ class ConversationRenameApi(InstalledAppResource):
         conversation_id = str(c_id)

         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=False, location='json')
-        parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
+        parser.add_argument("name", type=str, required=False, location="json")
+        parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
         args = parser.parse_args()

         try:
             return ConversationService.rename(
-                app_model,
-                conversation_id,
-                current_user,
-                args['name'],
-                args['auto_generate']
+                app_model, conversation_id, current_user, args["name"], args["auto_generate"]
             )
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")


 class ConversationPinApi(InstalledAppResource):
-
     def patch(self, installed_app, c_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)

@@ -123,8 +116,26 @@ class ConversationUnPinApi(InstalledAppResource):
         return {"result": "success"}


-api.add_resource(ConversationRenameApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name', endpoint='installed_app_conversation_rename')
-api.add_resource(ConversationListApi, '/installed-apps/<uuid:installed_app_id>/conversations', endpoint='installed_app_conversations')
-api.add_resource(ConversationApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>', endpoint='installed_app_conversation')
-api.add_resource(ConversationPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin', endpoint='installed_app_conversation_pin')
-api.add_resource(ConversationUnPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin', endpoint='installed_app_conversation_unpin')
+api.add_resource(
+    ConversationRenameApi,
+    "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
+    endpoint="installed_app_conversation_rename",
+)
+api.add_resource(
+    ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
+)
+api.add_resource(
+    ConversationApi,
+    "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
+    endpoint="installed_app_conversation",
+)
+api.add_resource(
+    ConversationPinApi,
+    "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
+    endpoint="installed_app_conversation_pin",
+)
+api.add_resource(
+    ConversationUnPinApi,
+    "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
+    endpoint="installed_app_conversation_unpin",
+)
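Note: the pinned query argument arrives as the string "true"/"false"; the handler's conversion, equivalent to the reformatted ternary above:

def parse_pinned(args: dict):
    # returns True/False when the argument is present, otherwise None
    if "pinned" in args and args["pinned"] is not None:
        return args["pinned"] == "true"
    return None


assert parse_pinned({"pinned": "true"}) is True
assert parse_pinned({"pinned": "false"}) is False
assert parse_pinned({}) is None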
@@ -2,24 +2,24 @@ from libs.exception import BaseHTTPException


 class NotCompletionAppError(BaseHTTPException):
-    error_code = 'not_completion_app'
+    error_code = "not_completion_app"
     description = "Not Completion App"
     code = 400


 class NotChatAppError(BaseHTTPException):
-    error_code = 'not_chat_app'
+    error_code = "not_chat_app"
     description = "App mode is invalid."
     code = 400


 class NotWorkflowAppError(BaseHTTPException):
-    error_code = 'not_workflow_app'
+    error_code = "not_workflow_app"
     description = "Only support workflow app."
     code = 400


 class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
-    error_code = 'app_suggested_questions_after_answer_disabled'
+    error_code = "app_suggested_questions_after_answer_disabled"
     description = "Function Suggested questions after answer disabled."
     code = 403
@@ -21,72 +21,72 @@ class InstalledAppsListApi(Resource):
     @marshal_with(installed_app_list_fields)
     def get(self):
         current_tenant_id = current_user.current_tenant_id
-        installed_apps = db.session.query(InstalledApp).filter(
-            InstalledApp.tenant_id == current_tenant_id
-        ).all()
+        installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()

         current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
         installed_apps = [
             {
-                'id': installed_app.id,
-                'app': installed_app.app,
-                'app_owner_tenant_id': installed_app.app_owner_tenant_id,
-                'is_pinned': installed_app.is_pinned,
-                'last_used_at': installed_app.last_used_at,
-                'editable': current_user.role in ["owner", "admin"],
-                'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id
+                "id": installed_app.id,
+                "app": installed_app.app,
+                "app_owner_tenant_id": installed_app.app_owner_tenant_id,
+                "is_pinned": installed_app.is_pinned,
+                "last_used_at": installed_app.last_used_at,
+                "editable": current_user.role in ["owner", "admin"],
+                "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
             }
             for installed_app in installed_apps
             if installed_app.app is not None
         ]
-        installed_apps.sort(key=lambda app: (-app['is_pinned'],
-                                             app['last_used_at'] is None,
-                                             -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))
+        installed_apps.sort(
+            key=lambda app: (
+                -app["is_pinned"],
+                app["last_used_at"] is None,
+                -app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0,
+            )
+        )

-        return {'installed_apps': installed_apps}
+        return {"installed_apps": installed_apps}

     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('apps')
+    @cloud_edition_billing_resource_check("apps")
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('app_id', type=str, required=True, help='Invalid app_id')
+        parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
         args = parser.parse_args()

-        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
+        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
         if recommended_app is None:
-            raise NotFound('App not found')
+            raise NotFound("App not found")

         current_tenant_id = current_user.current_tenant_id
-        app = db.session.query(App).filter(
-            App.id == args['app_id']
-        ).first()
+        app = db.session.query(App).filter(App.id == args["app_id"]).first()

         if app is None:
-            raise NotFound('App not found')
+            raise NotFound("App not found")

         if not app.is_public:
-            raise Forbidden('You can\'t install a non-public app')
+            raise Forbidden("You can't install a non-public app")

-        installed_app = InstalledApp.query.filter(and_(
-            InstalledApp.app_id == args['app_id'],
-            InstalledApp.tenant_id == current_tenant_id
-        )).first()
+        installed_app = InstalledApp.query.filter(
+            and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)
+        ).first()

         if installed_app is None:
             # todo: position
             recommended_app.install_count += 1

             new_installed_app = InstalledApp(
-                app_id=args['app_id'],
+                app_id=args["app_id"],
                 tenant_id=current_tenant_id,
                 app_owner_tenant_id=app.tenant_id,
                 is_pinned=False,
-                last_used_at=datetime.now(timezone.utc).replace(tzinfo=None)
+                last_used_at=datetime.now(timezone.utc).replace(tzinfo=None),
             )
             db.session.add(new_installed_app)
             db.session.commit()

-        return {'message': 'App installed successfully'}
+        return {"message": "App installed successfully"}


 class InstalledAppApi(InstalledAppResource):

@@ -94,30 +94,31 @@ class InstalledAppApi(InstalledAppResource):
     update and delete an installed app
     use InstalledAppResource to apply default decorators and get installed_app
     """

     def delete(self, installed_app):
         if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
-            raise BadRequest('You can\'t uninstall an app owned by the current tenant')
+            raise BadRequest("You can't uninstall an app owned by the current tenant")

         db.session.delete(installed_app)
         db.session.commit()

-        return {'result': 'success', 'message': 'App uninstalled successfully'}
+        return {"result": "success", "message": "App uninstalled successfully"}

     def patch(self, installed_app):
         parser = reqparse.RequestParser()
-        parser.add_argument('is_pinned', type=inputs.boolean)
+        parser.add_argument("is_pinned", type=inputs.boolean)
         args = parser.parse_args()

         commit_args = False
-        if 'is_pinned' in args:
-            installed_app.is_pinned = args['is_pinned']
+        if "is_pinned" in args:
+            installed_app.is_pinned = args["is_pinned"]
             commit_args = True

         if commit_args:
             db.session.commit()

-        return {'result': 'success', 'message': 'App info updated successfully'}
+        return {"result": "success", "message": "App info updated successfully"}


-api.add_resource(InstalledAppsListApi, '/installed-apps')
-api.add_resource(InstalledAppApi, '/installed-apps/<uuid:installed_app_id>')
+api.add_resource(InstalledAppsListApi, "/installed-apps")
+api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")
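Note: the three-part sort key reformatted above orders pinned apps first, then most recently used, with never-used apps last; demonstrated standalone with made-up data:

from datetime import datetime

apps = [
    {"is_pinned": False, "last_used_at": datetime(2024, 7, 1)},
    {"is_pinned": True, "last_used_at": None},
    {"is_pinned": False, "last_used_at": None},
    {"is_pinned": False, "last_used_at": datetime(2024, 7, 9)},
]
apps.sort(
    key=lambda app: (
        -app["is_pinned"],                    # True sorts as -1, so pinned come first
        app["last_used_at"] is None,          # never-used apps sink below dated ones
        -app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0,  # newest first
    )
)
print(apps)  # pinned app, then 2024-07-09, 2024-07-01, then the never-used app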
@@ -44,19 +44,21 @@ class MessageListApi(InstalledAppResource):
             raise NotChatAppError()

         parser = reqparse.RequestParser()
-        parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
-        parser.add_argument('first_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
+        parser.add_argument("first_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
         args = parser.parse_args()

         try:
-            return MessageService.pagination_by_first_id(app_model, current_user,
-                                                         args['conversation_id'], args['first_id'], args['limit'])
+            return MessageService.pagination_by_first_id(
+                app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
+            )
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
         except services.errors.message.FirstMessageNotExistsError:
             raise NotFound("First Message Not Exists.")


 class MessageFeedbackApi(InstalledAppResource):
     def post(self, installed_app, message_id):
         app_model = installed_app.app

@@ -64,30 +66,32 @@ class MessageFeedbackApi(InstalledAppResource):
         message_id = str(message_id)

         parser = reqparse.RequestParser()
-        parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
+        parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
         args = parser.parse_args()

         try:
-            MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
+            MessageService.create_feedback(app_model, message_id, current_user, args["rating"])
         except services.errors.message.MessageNotExistsError:
             raise NotFound("Message Not Exists.")

-        return {'result': 'success'}
+        return {"result": "success"}


 class MessageMoreLikeThisApi(InstalledAppResource):
     def get(self, installed_app, message_id):
         app_model = installed_app.app
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()

         message_id = str(message_id)

         parser = reqparse.RequestParser()
-        parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
+        parser.add_argument(
+            "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
+        )
         args = parser.parse_args()

-        streaming = args['response_mode'] == 'streaming'
+        streaming = args["response_mode"] == "streaming"

         try:
             response = AppGenerateService.generate_more_like_this(

@@ -95,7 +99,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
                 user=current_user,
                 message_id=message_id,
                 invoke_from=InvokeFrom.EXPLORE,
-                streaming=streaming
+                streaming=streaming,
             )
             return helper.compact_generate_response(response)
         except MessageNotExistsError:

@@ -128,10 +132,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
         try:
             questions = MessageService.get_suggested_questions_after_answer(
-                app_model=app_model,
-                user=current_user,
-                message_id=message_id,
-                invoke_from=InvokeFrom.EXPLORE
+                app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
             )
         except MessageNotExistsError:
             raise NotFound("Message not found")

@@ -151,10 +152,22 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
             logging.exception("internal server error.")
             raise InternalServerError()

-        return {'data': questions}
+        return {"data": questions}


-api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages')
-api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback')
-api.add_resource(MessageMoreLikeThisApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this', endpoint='installed_app_more_like_this')
-api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question')
+api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
+api.add_resource(
+    MessageFeedbackApi,
+    "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
+    endpoint="installed_app_message_feedback",
+)
+api.add_resource(
+    MessageMoreLikeThisApi,
+    "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
+    endpoint="installed_app_more_like_this",
+)
+api.add_resource(
+    MessageSuggestedQuestionApi,
+    "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
+    endpoint="installed_app_suggested_question",
+)
@@ -1,4 +1,3 @@
-
 from flask_restful import fields, marshal_with

 from configs import dify_config

@@ -11,33 +10,32 @@ from services.app_service import AppService

 class AppParameterApi(InstalledAppResource):
     """Resource for app variables."""

     variable_fields = {
-        'key': fields.String,
-        'name': fields.String,
-        'description': fields.String,
-        'type': fields.String,
-        'default': fields.String,
-        'max_length': fields.Integer,
-        'options': fields.List(fields.String)
+        "key": fields.String,
+        "name": fields.String,
+        "description": fields.String,
+        "type": fields.String,
+        "default": fields.String,
+        "max_length": fields.Integer,
+        "options": fields.List(fields.String),
     }

-    system_parameters_fields = {
-        'image_file_size_limit': fields.String
-    }
+    system_parameters_fields = {"image_file_size_limit": fields.String}

     parameters_fields = {
-        'opening_statement': fields.String,
-        'suggested_questions': fields.Raw,
-        'suggested_questions_after_answer': fields.Raw,
-        'speech_to_text': fields.Raw,
-        'text_to_speech': fields.Raw,
-        'retriever_resource': fields.Raw,
-        'annotation_reply': fields.Raw,
-        'more_like_this': fields.Raw,
-        'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw,
-        'file_upload': fields.Raw,
-        'system_parameters': fields.Nested(system_parameters_fields)
+        "opening_statement": fields.String,
+        "suggested_questions": fields.Raw,
+        "suggested_questions_after_answer": fields.Raw,
+        "speech_to_text": fields.Raw,
+        "text_to_speech": fields.Raw,
+        "retriever_resource": fields.Raw,
+        "annotation_reply": fields.Raw,
+        "more_like_this": fields.Raw,
+        "user_input_form": fields.Raw,
+        "sensitive_word_avoidance": fields.Raw,
+        "file_upload": fields.Raw,
+        "system_parameters": fields.Nested(system_parameters_fields),
     }

     @marshal_with(parameters_fields)

@@ -56,30 +54,35 @@ class AppParameterApi(InstalledAppResource):
             app_model_config = app_model.app_model_config
             features_dict = app_model_config.to_dict()

-        user_input_form = features_dict.get('user_input_form', [])
+        user_input_form = features_dict.get("user_input_form", [])

         return {
-            'opening_statement': features_dict.get('opening_statement'),
-            'suggested_questions': features_dict.get('suggested_questions', []),
-            'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
-                                                                  {"enabled": False}),
-            'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
-            'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
-            'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
-            'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
-            'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
-            'user_input_form': user_input_form,
-            'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
-                                                          {"enabled": False, "type": "", "configs": []}),
-            'file_upload': features_dict.get('file_upload', {"image": {
-                "enabled": False,
-                "number_limits": 3,
-                "detail": "high",
-                "transfer_methods": ["remote_url", "local_file"]
-            }}),
-            'system_parameters': {
-                'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
-            }
+            "opening_statement": features_dict.get("opening_statement"),
+            "suggested_questions": features_dict.get("suggested_questions", []),
+            "suggested_questions_after_answer": features_dict.get(
+                "suggested_questions_after_answer", {"enabled": False}
+            ),
+            "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
+            "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
+            "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
+            "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
+            "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
+            "user_input_form": user_input_form,
+            "sensitive_word_avoidance": features_dict.get(
+                "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
+            ),
+            "file_upload": features_dict.get(
+                "file_upload",
+                {
+                    "image": {
+                        "enabled": False,
+                        "number_limits": 3,
+                        "detail": "high",
+                        "transfer_methods": ["remote_url", "local_file"],
+                    }
+                },
+            ),
+            "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
         }

@@ -90,6 +93,7 @@ class ExploreAppMetaApi(InstalledAppResource):
         return AppService().get_app_meta(app_model)


-api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters',
-                 endpoint='installed_app_parameters')
-api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta')
+api.add_resource(
+    AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
+)
+api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
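Note: the features_dict.get(...) pattern above supplies a complete default sub-dict whenever a feature block is absent; a tiny demonstration with made-up data:

features_dict = {"speech_to_text": {"enabled": True}}

speech_to_text = features_dict.get("speech_to_text", {"enabled": False})
retriever_resource = features_dict.get("retriever_resource", {"enabled": False})

print(speech_to_text)      # {'enabled': True}   (key present, default unused)
print(retriever_resource)  # {'enabled': False}  (key absent, default returned)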
@@ -8,28 +8,28 @@ from libs.login import login_required
 from services.recommended_app_service import RecommendedAppService

 app_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'mode': fields.String,
-    'icon': fields.String,
-    'icon_background': fields.String
+    "id": fields.String,
+    "name": fields.String,
+    "mode": fields.String,
+    "icon": fields.String,
+    "icon_background": fields.String,
 }

 recommended_app_fields = {
-    'app': fields.Nested(app_fields, attribute='app'),
-    'app_id': fields.String,
-    'description': fields.String(attribute='description'),
-    'copyright': fields.String,
-    'privacy_policy': fields.String,
-    'custom_disclaimer': fields.String,
-    'category': fields.String,
-    'position': fields.Integer,
-    'is_listed': fields.Boolean
+    "app": fields.Nested(app_fields, attribute="app"),
+    "app_id": fields.String,
+    "description": fields.String(attribute="description"),
+    "copyright": fields.String,
+    "privacy_policy": fields.String,
+    "custom_disclaimer": fields.String,
+    "category": fields.String,
+    "position": fields.Integer,
+    "is_listed": fields.Boolean,
 }

 recommended_app_list_fields = {
-    'recommended_apps': fields.List(fields.Nested(recommended_app_fields)),
-    'categories': fields.List(fields.String)
+    "recommended_apps": fields.List(fields.Nested(recommended_app_fields)),
+    "categories": fields.List(fields.String),
 }

@@ -40,11 +40,11 @@ class RecommendedAppListApi(Resource):
     def get(self):
         # language args
         parser = reqparse.RequestParser()
-        parser.add_argument('language', type=str, location='args')
+        parser.add_argument("language", type=str, location="args")
         args = parser.parse_args()

-        if args.get('language') and args.get('language') in languages:
-            language_prefix = args.get('language')
+        if args.get("language") and args.get("language") in languages:
+            language_prefix = args.get("language")
         elif current_user and current_user.interface_language:
             language_prefix = current_user.interface_language
         else:

@@ -61,5 +61,5 @@ class RecommendedAppApi(Resource):
         return RecommendedAppService.get_recommend_app_detail(app_id)


-api.add_resource(RecommendedAppListApi, '/explore/apps')
-api.add_resource(RecommendedAppApi, '/explore/apps/<uuid:app_id>')
+api.add_resource(RecommendedAppListApi, "/explore/apps")
+api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")
|
@ -11,56 +11,54 @@ from libs.helper import TimestampField, uuid_value
|
|||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
}
|
||||
feedback_fields = {"rating": fields.String}
|
||||
|
||||
message_fields = {
|
||||
'id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'created_at': TimestampField
|
||||
"id": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String,
|
||||
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
|
||||
class SavedMessageListApi(InstalledAppResource):
|
||||
saved_message_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(message_fields))
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
|
||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||
def get(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != 'completion':
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
return SavedMessageService.pagination_by_last_id(app_model, current_user, args['last_id'], args['limit'])
|
||||
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
|
||||
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != 'completion':
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('message_id', type=uuid_value, required=True, location='json')
|
||||
parser.add_argument("message_id", type=uuid_value, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
SavedMessageService.save(app_model, current_user, args['message_id'])
|
||||
SavedMessageService.save(app_model, current_user, args["message_id"])
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {'result': 'success'}
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class SavedMessageApi(InstalledAppResource):
|
||||
|
@ -69,13 +67,21 @@ class SavedMessageApi(InstalledAppResource):
|
|||
|
||||
message_id = str(message_id)
|
||||
|
||||
if app_model.mode != 'completion':
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
SavedMessageService.delete(app_model, current_user, message_id)
|
||||
|
||||
return {'result': 'success'}
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(SavedMessageListApi, '/installed-apps/<uuid:installed_app_id>/saved-messages', endpoint='installed_app_saved_messages')
|
||||
api.add_resource(SavedMessageApi, '/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>', endpoint='installed_app_saved_message')
|
||||
api.add_resource(
|
||||
SavedMessageListApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/saved-messages",
|
||||
endpoint="installed_app_saved_messages",
|
||||
)
|
||||
api.add_resource(
|
||||
SavedMessageApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
|
||||
endpoint="installed_app_saved_message",
|
||||
)
|
||||
|
|
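Note: a minimal runnable flask_restful marshal example in the style of message_fields above; the record is made up, and repo-internal pieces (TimestampField, message_file_fields) are omitted:

from flask_restful import fields, marshal

feedback_fields = {"rating": fields.String}
message_fields = {
    "id": fields.String,
    "query": fields.String,
    "answer": fields.String,
    # attribute= maps the output key to a differently named source attribute
    "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
}

record = {"id": "1", "query": "hi", "answer": "hello", "user_feedback": {"rating": "like"}}
print(marshal(record, message_fields))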
|
@ -35,17 +35,13 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
|||
raise NotWorkflowAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
streaming=True
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
|
@ -76,10 +72,10 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
|||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
return {
|
||||
"result": "success"
|
||||
}
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(InstalledAppWorkflowRunApi, '/installed-apps/<uuid:installed_app_id>/workflows/run')
|
||||
api.add_resource(InstalledAppWorkflowTaskStopApi, '/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop')
|
||||
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||
api.add_resource(
|
||||
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
|
||||
)
|
||||
|
|
Some files were not shown because too many files have changed in this diff.