mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
Merge branch 'fix/chore-fix' into dev/plugin-deploy
This commit is contained in:
commit
7cea6c1713
|
@ -1,3 +1,3 @@
|
|||
#!/bin/bash
|
||||
|
||||
poetry install -C api
|
||||
cd api && poetry install
|
4
.github/workflows/build-push.yml
vendored
4
.github/workflows/build-push.yml
vendored
|
@ -50,7 +50,7 @@ jobs:
|
|||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
@ -115,7 +115,7 @@ jobs:
|
|||
merge-multiple: true
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
|
153
README.md
153
README.md
|
@ -46,6 +46,56 @@
|
|||
</p>
|
||||
|
||||
|
||||
## Table of Content
|
||||
0. [Quick-Start🚀](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start)
|
||||
|
||||
1. [Intro📖](https://github.com/langgenius/dify?tab=readme-ov-file#intro)
|
||||
|
||||
2. [How to use🔧](https://github.com/langgenius/dify?tab=readme-ov-file#using-dify)
|
||||
|
||||
3. [Stay Ahead🏃](https://github.com/langgenius/dify?tab=readme-ov-file#staying-ahead)
|
||||
|
||||
4. [Next Steps🏹](https://github.com/langgenius/dify?tab=readme-ov-file#next-steps)
|
||||
|
||||
5. [Contributing💪](https://github.com/langgenius/dify?tab=readme-ov-file#contributing)
|
||||
|
||||
6. [Community and Contact🏠](https://github.com/langgenius/dify?tab=readme-ov-file#community--contact)
|
||||
|
||||
7. [Star-History📈](https://github.com/langgenius/dify?tab=readme-ov-file#star-history)
|
||||
|
||||
8. [Security🔒](https://github.com/langgenius/dify?tab=readme-ov-file#security-disclosure)
|
||||
|
||||
9. [License🤝](https://github.com/langgenius/dify?tab=readme-ov-file#license)
|
||||
|
||||
> Make sure you read through this README before you start utilizing Dify😊
|
||||
|
||||
|
||||
## Quick start
|
||||
The quickest way to deploy Dify locally is to run our [docker-compose.yml](https://github.com/langgenius/dify/blob/main/docker/docker-compose.yaml). Follow the instructions to start in 5 minutes.
|
||||
|
||||
> Before installing Dify, make sure your machine meets the following minimum system requirements:
|
||||
>
|
||||
>- CPU >= 2 Core
|
||||
>- RAM >= 4 GiB
|
||||
>- Docker and Docker Compose Installed
|
||||
</br>
|
||||
|
||||
Run the following command in your terminal to clone the whole repo.
|
||||
```bash
|
||||
git clone https://github.com/langgenius/dify.git
|
||||
```
|
||||
After cloning,run the following command one by one.
|
||||
```bash
|
||||
cd dify
|
||||
cd docker
|
||||
cp .env.example .env
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. You will be asked to setup an admin account.
|
||||
For more info of quick setup, check [here](https://docs.dify.ai/getting-started/install-self-hosted/docker-compose)
|
||||
|
||||
## Intro
|
||||
Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:
|
||||
</br> </br>
|
||||
|
||||
|
@ -79,73 +129,6 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
|
|||
All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.
|
||||
|
||||
|
||||
## Feature comparison
|
||||
<table style="width: 100%;">
|
||||
<tr>
|
||||
<th align="center">Feature</th>
|
||||
<th align="center">Dify.AI</th>
|
||||
<th align="center">LangChain</th>
|
||||
<th align="center">Flowise</th>
|
||||
<th align="center">OpenAI Assistants API</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Programming Approach</td>
|
||||
<td align="center">API + App-oriented</td>
|
||||
<td align="center">Python Code</td>
|
||||
<td align="center">App-oriented</td>
|
||||
<td align="center">API-oriented</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Supported LLMs</td>
|
||||
<td align="center">Rich Variety</td>
|
||||
<td align="center">Rich Variety</td>
|
||||
<td align="center">Rich Variety</td>
|
||||
<td align="center">OpenAI-only</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">RAG Engine</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Agent</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Workflow</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Observability</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Enterprise Features (SSO/Access control)</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Local Deployment</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Using Dify
|
||||
|
||||
- **Cloud </br>**
|
||||
|
@ -166,30 +149,21 @@ Star Dify on GitHub and be instantly notified of new releases.
|
|||
|
||||
![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4)
|
||||
|
||||
|
||||
|
||||
## Quick start
|
||||
> Before installing Dify, make sure your machine meets the following minimum system requirements:
|
||||
>
|
||||
>- CPU >= 2 Core
|
||||
>- RAM >= 4 GiB
|
||||
|
||||
</br>
|
||||
|
||||
The easiest way to start the Dify server is to run our [docker-compose.yml](docker/docker-compose.yaml) file. Before running the installation command, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
cp .env.example .env
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.
|
||||
|
||||
> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)
|
||||
|
||||
## Next steps
|
||||
|
||||
Go to [quick-start](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start) to setup your Dify or setup by source code.
|
||||
|
||||
#### If you......
|
||||
If you forget your admin account, you can refer to this [guide](https://docs.dify.ai/getting-started/install-self-hosted/faqs#id-4.-how-to-reset-the-password-of-the-admin-account) to reset the password.
|
||||
|
||||
> Use docker compose up without "-d" to enable logs printing out in your terminal. This might be useful if you have encountered unknow problems when using Dify.
|
||||
|
||||
If you encountered system error and would like to acquire help in Github issues, make sure you always paste logs of the error in the request to accerate the conversation. Go to [Community & contact](https://github.com/langgenius/dify?tab=readme-ov-file#community--contact) for more information.
|
||||
|
||||
> Please read the [Dify Documentation](https://docs.dify.ai/) for detailed how-to-use guidance. Most of the potential problems are explained in the doc.
|
||||
|
||||
> If you'd like to contribute to Dify or make additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)
|
||||
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
|
||||
|
@ -228,6 +202,7 @@ At the same time, please consider supporting Dify by sharing it on social media
|
|||
* [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
* [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
|
||||
* [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.
|
||||
* Make sure a log, if possible, is attached to an error reported to maximize solution efficiency.
|
||||
|
||||
## Star history
|
||||
|
||||
|
|
|
@ -55,7 +55,12 @@ RUN apt-get update \
|
|||
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
|
||||
&& apt-get update \
|
||||
# For Security
|
||||
&& apt-get install -y --no-install-recommends libsqlite3-0=3.46.1-1 \
|
||||
&& apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \
|
||||
&& if [ "$(dpkg --print-architecture)" = "amd64" ]; then \
|
||||
apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1+b1; \
|
||||
else \
|
||||
apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1; \
|
||||
fi \
|
||||
# install a chinese font to support the use of tools like matplotlib
|
||||
&& apt-get install -y fonts-noto-cjk \
|
||||
&& apt-get autoremove -y \
|
||||
|
|
|
@ -10,7 +10,6 @@ from pydantic import (
|
|||
PositiveInt,
|
||||
computed_field,
|
||||
)
|
||||
from pydantic_extra_types.timezone_name import TimeZoneName
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from configs.feature.hosted_service import HostedServiceConfig
|
||||
|
@ -393,9 +392,8 @@ class LoggingConfig(BaseSettings):
|
|||
default=None,
|
||||
)
|
||||
|
||||
LOG_TZ: Optional[TimeZoneName] = Field(
|
||||
description="Timezone for log timestamps. Allowed timezone values can be referred to IANA Time Zone Database,"
|
||||
" e.g., 'America/New_York')",
|
||||
LOG_TZ: Optional[str] = Field(
|
||||
description="Timezone for log timestamps (e.g., 'America/New_York')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ from configs.middleware.storage.supabase_storage_config import SupabaseStorageCo
|
|||
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
|
||||
from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig
|
||||
from configs.middleware.vdb.chroma_config import ChromaConfig
|
||||
from configs.middleware.vdb.couchbase_config import CouchbaseConfig
|
||||
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
|
||||
|
@ -259,5 +260,6 @@ class MiddlewareConfig(
|
|||
UpstashConfig,
|
||||
TidbOnQdrantConfig,
|
||||
OceanBaseVectorConfig,
|
||||
BaiduVectorDBConfig,
|
||||
):
|
||||
pass
|
||||
|
|
6
api/controllers/common/errors.py
Normal file
6
api/controllers/common/errors.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
|
||||
class FilenameNotExistsError(HTTPException):
|
||||
code = 400
|
||||
description = "The specified filename does not exist."
|
58
api/controllers/common/helpers.py
Normal file
58
api/controllers/common/helpers.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import urllib.parse
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class FileInfo(BaseModel):
|
||||
filename: str
|
||||
extension: str
|
||||
mimetype: str
|
||||
size: int
|
||||
|
||||
|
||||
def guess_file_info_from_response(response: httpx.Response):
|
||||
url = str(response.url)
|
||||
# Try to extract filename from URL
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
filename = os.path.basename(url_path)
|
||||
|
||||
# If filename couldn't be extracted, use Content-Disposition header
|
||||
if not filename:
|
||||
content_disposition = response.headers.get("Content-Disposition")
|
||||
if content_disposition:
|
||||
filename_match = re.search(r'filename="?(.+)"?', content_disposition)
|
||||
if filename_match:
|
||||
filename = filename_match.group(1)
|
||||
|
||||
# If still no filename, generate a unique one
|
||||
if not filename:
|
||||
unique_name = str(uuid4())
|
||||
filename = f"{unique_name}"
|
||||
|
||||
# Guess MIME type from filename first, then URL
|
||||
mimetype, _ = mimetypes.guess_type(filename)
|
||||
if mimetype is None:
|
||||
mimetype, _ = mimetypes.guess_type(url)
|
||||
if mimetype is None:
|
||||
# If guessing fails, use Content-Type from response headers
|
||||
mimetype = response.headers.get("Content-Type", "application/octet-stream")
|
||||
|
||||
extension = os.path.splitext(filename)[1]
|
||||
|
||||
# Ensure filename has an extension
|
||||
if not extension:
|
||||
extension = mimetypes.guess_extension(mimetype) or ".bin"
|
||||
filename = f"{filename}{extension}"
|
||||
|
||||
return FileInfo(
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mimetype=mimetype,
|
||||
size=int(response.headers.get("Content-Length", -1)),
|
||||
)
|
|
@ -2,9 +2,21 @@ from flask import Blueprint
|
|||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
|
||||
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||
|
||||
bp = Blueprint("console", __name__, url_prefix="/console/api")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
# File
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
||||
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
||||
|
||||
# Remote files
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
|
||||
|
||||
# Import other controllers
|
||||
from . import admin, apikey, extension, feature, ping, setup, version
|
||||
|
||||
|
@ -43,7 +55,6 @@ from .datasets import (
|
|||
datasets_document,
|
||||
datasets_segments,
|
||||
external,
|
||||
file,
|
||||
hit_testing,
|
||||
website,
|
||||
)
|
||||
|
|
|
@ -12,8 +12,7 @@ from models.dataset import Dataset
|
|||
from models.model import ApiToken, App
|
||||
|
||||
from . import api
|
||||
from .setup import setup_required
|
||||
from .wraps import account_initialization_required
|
||||
from .wraps import account_initialization_required, setup_required
|
||||
|
||||
api_key_fields = {
|
||||
"id": fields.String,
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
|
|
|
@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
|
|
|
@ -6,8 +6,11 @@ from werkzeug.exceptions import Forbidden
|
|||
from controllers.console import api
|
||||
from controllers.console.app.error import NoFileUploadedError
|
||||
from controllers.console.datasets.error import TooManyFilesError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
|
|
|
@ -6,8 +6,11 @@ from werkzeug.exceptions import BadRequest, Forbidden, abort
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from fields.app_fields import (
|
||||
app_detail_fields,
|
||||
|
|
|
@ -18,8 +18,7 @@ from controllers.console.app.error import (
|
|||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
|
|
|
@ -15,8 +15,7 @@ from controllers.console.app.error import (
|
|||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
|
|
@ -10,8 +10,7 @@ from werkzeug.exceptions import Forbidden, NotFound
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
|
|
|
@ -4,8 +4,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_variable_fields import paginated_conversation_variable_fields
|
||||
from libs.login import login_required
|
||||
|
|
|
@ -10,8 +10,7 @@ from controllers.console.app.error import (
|
|||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
|
|
|
@ -14,8 +14,11 @@ from controllers.console.app.error import (
|
|||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
|
|
|
@ -6,8 +6,7 @@ from flask_restful import Resource
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
|
|
|
@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.ops_service import OpsService
|
||||
|
||||
|
|
|
@ -7,8 +7,7 @@ from werkzeug.exceptions import Forbidden, NotFound
|
|||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_site_fields
|
||||
from libs.login import login_required
|
||||
|
|
|
@ -8,8 +8,7 @@ from flask_restful import Resource, reqparse
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
|
|
|
@ -9,8 +9,7 @@ import services
|
|||
from controllers.console import api
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from factories import variable_factory
|
||||
|
|
|
@ -3,8 +3,7 @@ from flask_restful.inputs import int_range
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
|
|
|
@ -3,8 +3,7 @@ from flask_restful.inputs import int_range
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.workflow_run_fields import (
|
||||
advanced_chat_workflow_run_pagination_fields,
|
||||
workflow_run_detail_fields,
|
||||
|
|
|
@ -8,8 +8,7 @@ from flask_restful import Resource, reqparse
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
|
|
|
@ -7,8 +7,7 @@ from controllers.console.auth.error import ApiKeyAuthFailedError
|
|||
from libs.login import login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from ..setup import setup_required
|
||||
from ..wraps import account_initialization_required
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
||||
|
||||
class ApiKeyAuthDataSource(Resource):
|
||||
|
|
|
@ -11,8 +11,7 @@ from controllers.console import api
|
|||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
from ..setup import setup_required
|
||||
from ..wraps import account_initialization_required
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
||||
|
||||
def get_oauth_providers():
|
||||
|
|
|
@ -15,7 +15,7 @@ from controllers.console.auth.error import (
|
|||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.error import EmailSendIpLimitError, NotAllowedRegister
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
|
|
|
@ -20,7 +20,7 @@ from controllers.console.error import (
|
|||
NotAllowedCreateWorkspace,
|
||||
NotAllowedRegister,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
|
|
|
@ -2,8 +2,7 @@ from flask_login import current_user
|
|||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
|
|
@ -9,8 +9,7 @@ from sqlalchemy.orm import Session
|
|||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
|
|
|
@ -10,8 +10,7 @@ from controllers.console import api
|
|||
from controllers.console.apikey import api_key_fields, api_key_list
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
|
|
@ -23,8 +23,11 @@ from controllers.console.datasets.error import (
|
|||
InvalidActionError,
|
||||
InvalidMetadataError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
|
|
|
@ -11,11 +11,11 @@ import services
|
|||
from controllers.console import api
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_knowledge_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
|
|
|
@ -6,8 +6,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
|||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from services.dataset_service import DatasetService
|
||||
|
|
|
@ -2,8 +2,7 @@ from flask_restful import Resource
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
|
||||
|
||||
|
|
|
@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import WebsiteCrawlError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.website_service import WebsiteService
|
||||
|
||||
|
|
|
@ -3,8 +3,7 @@ from flask_restful import Resource, marshal_with, reqparse
|
|||
|
||||
from constants import HIDDEN_VALUE
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from libs.login import login_required
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
|
|
|
@ -5,8 +5,7 @@ from libs.login import login_required
|
|||
from services.feature_service import FeatureService
|
||||
|
||||
from . import api
|
||||
from .setup import setup_required
|
||||
from .wraps import account_initialization_required, cloud_utm_record
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
|
||||
|
||||
class FeatureApi(Resource):
|
||||
|
|
|
@ -1,25 +1,26 @@
|
|||
import urllib.parse
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from flask_restful import Resource, marshal_with
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import (
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from fields.file_fields import file_fields, upload_config_fields
|
||||
from libs.login import login_required
|
||||
from services.file_service import FileService
|
||||
|
||||
from .errors import (
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.helper import ssrf_proxy
|
||||
from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields
|
||||
from libs.login import login_required
|
||||
from services.file_service import FileService
|
||||
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
@ -44,21 +45,29 @@ class FileApi(Resource):
|
|||
@marshal_with(file_fields)
|
||||
@cloud_edition_billing_resource_check("documents")
|
||||
def post(self):
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
source = request.form.get("source")
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("source", type=str, required=False, location="args")
|
||||
source = parser.parse_args().get("source")
|
||||
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if source not in ("datasets", None):
|
||||
source = None
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(file=file, user=current_user, source=source)
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source=source,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
|
@ -83,23 +92,3 @@ class FileSupportTypeApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self):
|
||||
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
|
||||
|
||||
|
||||
class RemoteFileInfoApi(Resource):
|
||||
@marshal_with(remote_file_info_fields)
|
||||
def get(self, url):
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
try:
|
||||
response = ssrf_proxy.head(decoded_url)
|
||||
return {
|
||||
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||
"file_length": int(response.headers.get("Content-Length", 0)),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
||||
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
25
api/controllers/console/files/errors.py
Normal file
25
api/controllers/console/files/errors.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class FileTooLargeError(BaseHTTPException):
|
||||
error_code = "file_too_large"
|
||||
description = "File size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(BaseHTTPException):
|
||||
error_code = "unsupported_file_type"
|
||||
description = "File type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = "too_many_files"
|
||||
description = "Only one file is allowed."
|
||||
code = 400
|
||||
|
||||
|
||||
class NoFileUploadedError(BaseHTTPException):
|
||||
error_code = "no_file_uploaded"
|
||||
description = "Please upload your file."
|
||||
code = 400
|
71
api/controllers/console/remote_files.py
Normal file
71
api/controllers/console/remote_files.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
import urllib.parse
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
|
||||
from controllers.common import helpers
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from models.account import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class RemoteFileInfoApi(Resource):
|
||||
@marshal_with(remote_file_info_fields)
|
||||
def get(self, url):
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
try:
|
||||
response = ssrf_proxy.head(decoded_url)
|
||||
return {
|
||||
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||
"file_length": int(response.headers.get("Content-Length", 0)),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
|
||||
class RemoteFileUploadApi(Resource):
|
||||
@marshal_with(file_fields_with_signed_url)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("url", type=str, required=True, help="URL is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
url = args["url"]
|
||||
|
||||
response = ssrf_proxy.head(url)
|
||||
response.raise_for_status()
|
||||
|
||||
file_info = helpers.guess_file_info_from_response(response)
|
||||
|
||||
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
|
||||
return {"error": "File size exceeded"}, 400
|
||||
|
||||
response = ssrf_proxy.get(url)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
|
||||
try:
|
||||
user = cast(Account, current_user)
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
user=user,
|
||||
source_url=url,
|
||||
)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at,
|
||||
}, 201
|
|
@ -1,5 +1,3 @@
|
|||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
|
@ -10,7 +8,7 @@ from models.model import DifySetup, db
|
|||
from services.account_service import RegisterService, TenantService
|
||||
|
||||
from . import api
|
||||
from .error import AlreadySetupError, NotInitValidateError, NotSetupError
|
||||
from .error import AlreadySetupError, NotInitValidateError
|
||||
from .init_validate import get_init_validate_status
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
||||
|
@ -52,21 +50,6 @@ class SetupApi(Resource):
|
|||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
def setup_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# check setup
|
||||
if not get_init_validate_status():
|
||||
raise NotInitValidateError()
|
||||
|
||||
elif not get_setup_status():
|
||||
raise NotSetupError()
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def get_setup_status():
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
return db.session.query(DifySetup).first()
|
||||
|
|
|
@ -4,8 +4,7 @@ from flask_restful import Resource, marshal_with, reqparse
|
|||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.tag_fields import tag_fields
|
||||
from libs.login import login_required
|
||||
from models.model import Tag
|
||||
|
|
|
@ -3,6 +3,7 @@ import logging
|
|||
|
||||
import requests
|
||||
from flask_restful import Resource, reqparse
|
||||
from packaging import version
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
@ -47,42 +48,14 @@ class VersionApi(Resource):
|
|||
|
||||
|
||||
def _has_new_version(*, latest_version: str, current_version: str) -> bool:
|
||||
def parse_version(version: str) -> tuple:
|
||||
# Split version into parts and pre-release suffix if any
|
||||
parts = version.split("-")
|
||||
version_parts = parts[0].split(".")
|
||||
pre_release = parts[1] if len(parts) > 1 else None
|
||||
|
||||
# Validate version format
|
||||
if len(version_parts) != 3:
|
||||
raise ValueError(f"Invalid version format: {version}")
|
||||
|
||||
try:
|
||||
# Convert version parts to integers
|
||||
major, minor, patch = map(int, version_parts)
|
||||
return (major, minor, patch, pre_release)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid version format: {version}")
|
||||
|
||||
latest = parse_version(latest_version)
|
||||
current = parse_version(current_version)
|
||||
|
||||
# Compare major, minor, and patch versions
|
||||
for latest_part, current_part in zip(latest[:3], current[:3]):
|
||||
if latest_part > current_part:
|
||||
return True
|
||||
elif latest_part < current_part:
|
||||
return False
|
||||
|
||||
# If versions are equal, check pre-release suffixes
|
||||
if latest[3] is None and current[3] is not None:
|
||||
return True
|
||||
elif latest[3] is not None and current[3] is None:
|
||||
return False
|
||||
elif latest[3] is not None and current[3] is not None:
|
||||
# Simple string comparison for pre-release versions
|
||||
return latest[3] > current[3]
|
||||
latest = version.parse(latest_version)
|
||||
current = version.parse(current_version)
|
||||
|
||||
# Compare versions
|
||||
return latest > current
|
||||
except version.InvalidVersion:
|
||||
logging.warning(f"Invalid version format: latest={latest_version}, current={current_version}")
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
@ -8,14 +8,13 @@ from flask_restful import Resource, fields, marshal_with, reqparse
|
|||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.workspace.error import (
|
||||
AccountAlreadyInitedError,
|
||||
CurrentPasswordIncorrectError,
|
||||
InvalidInvitationCodeError,
|
||||
RepeatPasswordNotMatchError,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from libs.helper import TimestampField, timezone
|
||||
|
|
|
@ -3,8 +3,7 @@ from flask_restful import Resource, reqparse
|
|||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
|
|
@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse
|
|||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from libs.login import current_user, login_required
|
||||
|
|
|
@ -4,8 +4,11 @@ from flask_restful import Resource, abort, marshal_with, reqparse
|
|||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_with_role_list_fields
|
||||
from libs.login import login_required
|
||||
|
|
|
@ -6,8 +6,7 @@ from flask_restful import Resource, reqparse
|
|||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
|
|
@ -5,8 +5,7 @@ from flask_restful import Resource, reqparse
|
|||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
|
|
@ -7,9 +7,8 @@ from werkzeug.exceptions import Forbidden
|
|||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from models.account import TenantPluginPermission
|
||||
|
|
|
@ -7,8 +7,7 @@ from werkzeug.exceptions import Forbidden
|
|||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
|
|
|
@ -6,6 +6,7 @@ from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqpa
|
|||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
from controllers.console import api
|
||||
from controllers.console.admin import admin_required
|
||||
from controllers.console.datasets.error import (
|
||||
|
@ -15,8 +16,11 @@ from controllers.console.datasets.error import (
|
|||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
|
@ -193,12 +197,20 @@ class WebappLogoWorkspaceApi(Resource):
|
|||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
extension = file.filename.split(".")[-1]
|
||||
if extension.lower() not in {"svg", "png"}:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(file=file, user=current_user)
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
)
|
||||
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import os
|
||||
from functools import wraps
|
||||
|
||||
from flask import abort, request
|
||||
|
@ -6,9 +7,13 @@ from flask_login import current_user
|
|||
|
||||
from configs import dify_config
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from extensions.ext_database import db
|
||||
from models.model import DifySetup
|
||||
from services.feature_service import FeatureService
|
||||
from services.operation_service import OperationService
|
||||
|
||||
from .error import NotInitValidateError, NotSetupError
|
||||
|
||||
|
||||
def account_initialization_required(view):
|
||||
@wraps(view)
|
||||
|
@ -124,3 +129,21 @@ def cloud_utm_record(view):
|
|||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def setup_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# check setup
|
||||
if (
|
||||
dify_config.EDITION == "SELF_HOSTED"
|
||||
and os.environ.get("INIT_PASSWORD")
|
||||
and not db.session.query(DifySetup).first()
|
||||
):
|
||||
raise NotInitValidateError()
|
||||
elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
|
||||
raise NotSetupError()
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from flask_restful import Resource
|
||||
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
|
||||
from controllers.inner_api.wraps import plugin_inner_api_only
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
from events.tenant_event import tenant_was_created
|
||||
|
|
|
@ -2,6 +2,7 @@ from flask import request
|
|||
from flask_restful import Resource, marshal_with
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import (
|
||||
FileTooLargeError,
|
||||
|
@ -31,8 +32,17 @@ class FileApi(Resource):
|
|||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(file, end_user)
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=end_user,
|
||||
source="datasets",
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
|
|
|
@ -6,6 +6,7 @@ from sqlalchemy import desc
|
|||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.dataset.error import (
|
||||
|
@ -55,7 +56,12 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||
if not dataset.indexing_technique and not args["indexing_technique"]:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
||||
upload_file = FileService.upload_text(args.get("text"), args.get("name"))
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if text is None or name is None:
|
||||
raise ValueError("Both 'text' and 'name' must be non-null values.")
|
||||
|
||||
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
|
@ -104,7 +110,11 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
|||
raise ValueError("Dataset is not exist.")
|
||||
|
||||
if args["text"]:
|
||||
upload_file = FileService.upload_text(args.get("text"), args.get("name"))
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if text is None or name is None:
|
||||
raise ValueError("Both text and name must be strings.")
|
||||
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
|
@ -163,7 +173,16 @@ class DocumentAddByFileApi(DatasetApiResource):
|
|||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
upload_file = FileService.upload_file(file, current_user)
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
|
@ -212,7 +231,16 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
|||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
upload_file = FileService.upload_file(file, current_user)
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
|
@ -331,10 +359,26 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
|||
return data
|
||||
|
||||
|
||||
api.add_resource(DocumentAddByTextApi, "/datasets/<uuid:dataset_id>/document/create_by_text")
|
||||
api.add_resource(DocumentAddByFileApi, "/datasets/<uuid:dataset_id>/document/create_by_file")
|
||||
api.add_resource(DocumentUpdateByTextApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text")
|
||||
api.add_resource(DocumentUpdateByFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file")
|
||||
api.add_resource(
|
||||
DocumentAddByTextApi,
|
||||
"/datasets/<uuid:dataset_id>/document/create_by_text",
|
||||
"/datasets/<uuid:dataset_id>/document/create-by-text",
|
||||
)
|
||||
api.add_resource(
|
||||
DocumentAddByFileApi,
|
||||
"/datasets/<uuid:dataset_id>/document/create_by_file",
|
||||
"/datasets/<uuid:dataset_id>/document/create-by-file",
|
||||
)
|
||||
api.add_resource(
|
||||
DocumentUpdateByTextApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
|
||||
)
|
||||
api.add_resource(
|
||||
DocumentUpdateByFileApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
|
||||
)
|
||||
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
|
||||
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
|
||||
|
|
|
@ -14,4 +14,4 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
|||
return self.perform_hit_testing(dataset, args)
|
||||
|
||||
|
||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
|
||||
|
|
|
@ -2,8 +2,17 @@ from flask import Blueprint
|
|||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
from .files import FileApi
|
||||
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||
|
||||
bp = Blueprint("web", __name__, url_prefix="/api")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
# Files
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
|
||||
from . import app, audio, completion, conversation, feature, file, message, passport, saved_message, site, workflow
|
||||
# Remote files
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
|
||||
|
||||
from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow
|
||||
|
|
|
@ -1,56 +0,0 @@
|
|||
import urllib.parse
|
||||
|
||||
from flask import request
|
||||
from flask_restful import marshal_with, reqparse
|
||||
|
||||
import services
|
||||
from controllers.web import api
|
||||
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.helper import ssrf_proxy
|
||||
from fields.file_fields import file_fields, remote_file_info_fields
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class FileApi(WebApiResource):
|
||||
@marshal_with(file_fields)
|
||||
def post(self, app_model, end_user):
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("source", type=str, required=False, location="args")
|
||||
source = parser.parse_args().get("source")
|
||||
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
try:
|
||||
upload_file = FileService.upload_file(file=file, user=end_user, source=source)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return upload_file, 201
|
||||
|
||||
|
||||
class RemoteFileInfoApi(WebApiResource):
|
||||
@marshal_with(remote_file_info_fields)
|
||||
def get(self, url):
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
try:
|
||||
response = ssrf_proxy.head(decoded_url)
|
||||
return {
|
||||
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||
"file_length": int(response.headers.get("Content-Length", -1)),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
43
api/controllers/web/files.py
Normal file
43
api/controllers/web/files.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from flask import request
|
||||
from flask_restful import marshal_with
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.file_fields import file_fields
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class FileApi(WebApiResource):
|
||||
@marshal_with(file_fields)
|
||||
def post(self, app_model, end_user):
|
||||
file = request.files["file"]
|
||||
source = request.form.get("source")
|
||||
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if source not in ("datasets", None):
|
||||
source = None
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=end_user,
|
||||
source=source,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return upload_file, 201
|
69
api/controllers/web/remote_files.py
Normal file
69
api/controllers/web/remote_files.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
import urllib.parse
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal_with, reqparse
|
||||
|
||||
from controllers.common import helpers
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class RemoteFileInfoApi(WebApiResource):
|
||||
@marshal_with(remote_file_info_fields)
|
||||
def get(self, url):
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
try:
|
||||
response = ssrf_proxy.head(decoded_url)
|
||||
return {
|
||||
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||
"file_length": int(response.headers.get("Content-Length", -1)),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
|
||||
class RemoteFileUploadApi(WebApiResource):
|
||||
@marshal_with(file_fields_with_signed_url)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("url", type=str, required=True, help="URL is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
url = args["url"]
|
||||
|
||||
response = ssrf_proxy.head(url)
|
||||
response.raise_for_status()
|
||||
|
||||
file_info = helpers.guess_file_info_from_response(response)
|
||||
|
||||
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
|
||||
return {"error": "File size exceeded"}, 400
|
||||
|
||||
response = ssrf_proxy.get(url)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
user=current_user,
|
||||
source_url=url,
|
||||
)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at,
|
||||
}, 201
|
|
@ -17,6 +17,7 @@ from core.errors.error import ProviderTokenNotInitError
|
|||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
|
@ -597,26 +598,9 @@ class IndexingRunner:
|
|||
rules = DatasetProcessRule.AUTOMATIC_RULES
|
||||
else:
|
||||
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
|
||||
document_text = CleanProcessor.clean(text, rules)
|
||||
|
||||
if "pre_processing_rules" in rules:
|
||||
pre_processing_rules = rules["pre_processing_rules"]
|
||||
for pre_processing_rule in pre_processing_rules:
|
||||
if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
|
||||
# Remove extra spaces
|
||||
pattern = r"\n{3,}"
|
||||
text = re.sub(pattern, "\n\n", text)
|
||||
pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}"
|
||||
text = re.sub(pattern, " ", text)
|
||||
elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
|
||||
# Remove email
|
||||
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
|
||||
text = re.sub(pattern, "", text)
|
||||
|
||||
# Remove URL
|
||||
pattern = r"https?://[^\s]+"
|
||||
text = re.sub(pattern, "", text)
|
||||
|
||||
return text
|
||||
return document_text
|
||||
|
||||
@staticmethod
|
||||
def format_split_text(text):
|
||||
|
|
|
@ -1,9 +0,0 @@
|
|||
- claude-3-5-sonnet-20241022
|
||||
- claude-3-5-sonnet-20240620
|
||||
- claude-3-haiku-20240307
|
||||
- claude-3-opus-20240229
|
||||
- claude-3-sonnet-20240229
|
||||
- claude-2.1
|
||||
- claude-instant-1.2
|
||||
- claude-2
|
||||
- claude-instant-1
|
|
@ -1,39 +0,0 @@
|
|||
model: claude-3-5-sonnet-20241022
|
||||
label:
|
||||
en_US: claude-3-5-sonnet-20241022
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '3.00'
|
||||
output: '15.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
|
@ -1,245 +0,0 @@
|
|||
provider: azure_openai
|
||||
label:
|
||||
en_US: Azure OpenAI Service Model
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.png
|
||||
background: "#E3F0FF"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API key from Azure
|
||||
zh_Hans: 从 Azure 获取 API Key
|
||||
url:
|
||||
en_US: https://azure.microsoft.com/en-us/products/ai-services/openai-service
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- speech2text
|
||||
- tts
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Deployment Name
|
||||
zh_Hans: 部署名称
|
||||
placeholder:
|
||||
en_US: Enter your Deployment Name here, matching the Azure deployment name.
|
||||
zh_Hans: 在此输入您的部署名称,与 Azure 部署名称匹配。
|
||||
credential_form_schemas:
|
||||
- variable: openai_api_base
|
||||
label:
|
||||
en_US: API Endpoint URL
|
||||
zh_Hans: API 域名
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: '在此输入您的 API 域名,如:https://example.com/xxx'
|
||||
en_US: 'Enter your API Endpoint, eg: https://example.com/xxx'
|
||||
- variable: openai_api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
zh_Hans: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API key here
|
||||
- variable: openai_api_version
|
||||
label:
|
||||
zh_Hans: API 版本
|
||||
en_US: API Version
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- label:
|
||||
en_US: 2024-10-01-preview
|
||||
value: 2024-10-01-preview
|
||||
- label:
|
||||
en_US: 2024-09-01-preview
|
||||
value: 2024-09-01-preview
|
||||
- label:
|
||||
en_US: 2024-08-01-preview
|
||||
value: 2024-08-01-preview
|
||||
- label:
|
||||
en_US: 2024-07-01-preview
|
||||
value: 2024-07-01-preview
|
||||
- label:
|
||||
en_US: 2024-05-01-preview
|
||||
value: 2024-05-01-preview
|
||||
- label:
|
||||
en_US: 2024-04-01-preview
|
||||
value: 2024-04-01-preview
|
||||
- label:
|
||||
en_US: 2024-03-01-preview
|
||||
value: 2024-03-01-preview
|
||||
- label:
|
||||
en_US: 2024-02-15-preview
|
||||
value: 2024-02-15-preview
|
||||
- label:
|
||||
en_US: 2023-12-01-preview
|
||||
value: 2023-12-01-preview
|
||||
- label:
|
||||
en_US: '2024-02-01'
|
||||
value: '2024-02-01'
|
||||
- label:
|
||||
en_US: '2024-06-01'
|
||||
value: '2024-06-01'
|
||||
placeholder:
|
||||
zh_Hans: 在此选择您的 API 版本
|
||||
en_US: Select your API Version here
|
||||
- variable: base_model_name
|
||||
label:
|
||||
en_US: Base Model
|
||||
zh_Hans: 基础模型
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- label:
|
||||
en_US: gpt-35-turbo
|
||||
value: gpt-35-turbo
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-35-turbo-0125
|
||||
value: gpt-35-turbo-0125
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-35-turbo-16k
|
||||
value: gpt-35-turbo-16k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4
|
||||
value: gpt-4
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-32k
|
||||
value: gpt-4-32k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: o1-mini
|
||||
value: o1-mini
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: o1-preview
|
||||
value: o1-preview
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4o-mini
|
||||
value: gpt-4o-mini
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4o-mini-2024-07-18
|
||||
value: gpt-4o-mini-2024-07-18
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4o
|
||||
value: gpt-4o
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4o-2024-05-13
|
||||
value: gpt-4o-2024-05-13
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4o-2024-08-06
|
||||
value: gpt-4o-2024-08-06
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-turbo
|
||||
value: gpt-4-turbo
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-turbo-2024-04-09
|
||||
value: gpt-4-turbo-2024-04-09
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-0125-preview
|
||||
value: gpt-4-0125-preview
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-1106-preview
|
||||
value: gpt-4-1106-preview
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-vision-preview
|
||||
value: gpt-4-vision-preview
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-35-turbo-instruct
|
||||
value: gpt-35-turbo-instruct
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: text-embedding-ada-002
|
||||
value: text-embedding-ada-002
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: text-embedding
|
||||
- label:
|
||||
en_US: text-embedding-3-small
|
||||
value: text-embedding-3-small
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: text-embedding
|
||||
- label:
|
||||
en_US: text-embedding-3-large
|
||||
value: text-embedding-3-large
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: text-embedding
|
||||
- label:
|
||||
en_US: whisper-1
|
||||
value: whisper-1
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: speech2text
|
||||
- label:
|
||||
en_US: tts-1
|
||||
value: tts-1
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: tts
|
||||
- label:
|
||||
en_US: tts-1-hd
|
||||
value: tts-1-hd
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: tts
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型版本
|
||||
en_US: Enter your model version
|
|
@ -1,764 +0,0 @@
|
|||
import copy
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import tiktoken
|
||||
from openai import AzureOpenAI, Stream
|
||||
from openai.types import Completion
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageFunction,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
|
||||
from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
|
||||
from core.model_runtime.utils import helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
base_model_name = self._get_base_model_name(credentials)
|
||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
|
||||
if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
return self._chat_generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
# text completion model
|
||||
return self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
base_model_name = self._get_base_model_name(credentials)
|
||||
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
if not model_entity:
|
||||
raise ValueError(f"Base Model Name {base_model_name} is invalid")
|
||||
model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
|
||||
|
||||
if model_mode == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
return self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
||||
else:
|
||||
# text completion model, do not support tool calling
|
||||
content = prompt_messages[0].content
|
||||
assert isinstance(content, str)
|
||||
return self._num_tokens_from_string(credentials, content)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
if "openai_api_base" not in credentials:
|
||||
raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required")
|
||||
|
||||
if "openai_api_key" not in credentials:
|
||||
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
|
||||
|
||||
if "base_model_name" not in credentials:
|
||||
raise CredentialsValidateFailedError("Base Model Name is required")
|
||||
|
||||
base_model_name = self._get_base_model_name(credentials)
|
||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
|
||||
if not ai_model_entity:
|
||||
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
|
||||
|
||||
try:
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
|
||||
if model.startswith("o1"):
|
||||
client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
model=model,
|
||||
temperature=1,
|
||||
max_completion_tokens=20,
|
||||
stream=False,
|
||||
)
|
||||
elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
model=model,
|
||||
temperature=0,
|
||||
max_tokens=20,
|
||||
stream=False,
|
||||
)
|
||||
else:
|
||||
# text completion model
|
||||
client.completions.create(
|
||||
prompt="ping",
|
||||
model=model,
|
||||
temperature=0,
|
||||
max_tokens=20,
|
||||
stream=False,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
base_model_name = self._get_base_model_name(credentials)
|
||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
return ai_model_entity.entity if ai_model_entity else None
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
|
||||
extra_model_kwargs = {}
|
||||
|
||||
if stop:
|
||||
extra_model_kwargs["stop"] = stop
|
||||
|
||||
if user:
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
# text completion model
|
||||
response = client.completions.create(
|
||||
prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_response(
|
||||
self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage]
|
||||
):
|
||||
assistant_text = response.choices[0].text
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
|
||||
|
||||
# calculate num tokens
|
||||
if response.usage:
|
||||
# transform usage
|
||||
prompt_tokens = response.usage.prompt_tokens
|
||||
completion_tokens = response.usage.completion_tokens
|
||||
else:
|
||||
# calculate num tokens
|
||||
content = prompt_messages[0].content
|
||||
assert isinstance(content, str)
|
||||
prompt_tokens = self._num_tokens_from_string(credentials, content)
|
||||
completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
model=response.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
full_text = ""
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0]
|
||||
|
||||
if delta.finish_reason is None and (delta.text is None or delta.text == ""):
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
text = delta.text or ""
|
||||
assistant_prompt_message = AssistantPromptMessage(content=text)
|
||||
|
||||
full_text += text
|
||||
|
||||
if delta.finish_reason is not None:
|
||||
# calculate num tokens
|
||||
if chunk.usage:
|
||||
# transform usage
|
||||
prompt_tokens = chunk.usage.prompt_tokens
|
||||
completion_tokens = chunk.usage.completion_tokens
|
||||
else:
|
||||
# calculate num tokens
|
||||
content = prompt_messages[0].content
|
||||
assert isinstance(content, str)
|
||||
prompt_tokens = self._num_tokens_from_string(credentials, content)
|
||||
completion_tokens = self._num_tokens_from_string(credentials, full_text)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=chunk.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=chunk.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=delta.finish_reason,
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
model=chunk.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=chunk.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
),
|
||||
)
|
||||
|
||||
def _chat_generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
|
||||
response_format = model_parameters.get("response_format")
|
||||
if response_format:
|
||||
if response_format == "json_schema":
|
||||
json_schema = model_parameters.get("json_schema")
|
||||
if not json_schema:
|
||||
raise ValueError("Must define JSON Schema when the response format is json_schema")
|
||||
try:
|
||||
schema = json.loads(json_schema)
|
||||
except:
|
||||
raise ValueError(f"not correct json_schema format: {json_schema}")
|
||||
model_parameters.pop("json_schema")
|
||||
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
|
||||
else:
|
||||
model_parameters["response_format"] = {"type": response_format}
|
||||
|
||||
extra_model_kwargs = {}
|
||||
|
||||
if tools:
|
||||
extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
|
||||
|
||||
if stop:
|
||||
extra_model_kwargs["stop"] = stop
|
||||
|
||||
if user:
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
# clear illegal prompt messages
|
||||
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
|
||||
|
||||
block_as_stream = False
|
||||
if model.startswith("o1"):
|
||||
if stream:
|
||||
block_as_stream = True
|
||||
stream = False
|
||||
|
||||
if "stream_options" in extra_model_kwargs:
|
||||
del extra_model_kwargs["stream_options"]
|
||||
|
||||
if "stop" in extra_model_kwargs:
|
||||
del extra_model_kwargs["stop"]
|
||||
|
||||
# chat model
|
||||
response = client.chat.completions.create(
|
||||
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
||||
model=model,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
if block_as_stream:
|
||||
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
|
||||
|
||||
return block_result
|
||||
|
||||
def _handle_chat_block_as_stream_response(
|
||||
self,
|
||||
block_result: LLMResult,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Handle llm chat response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
text = block_result.message.content
|
||||
text = cast(str, text)
|
||||
|
||||
if stop:
|
||||
text = self.enforce_stop_tokens(text, stop)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=block_result.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=block_result.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
finish_reason="stop",
|
||||
usage=block_result.usage,
|
||||
),
|
||||
)
|
||||
|
||||
def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Clear illegal prompt messages for OpenAI API
|
||||
|
||||
:param model: model name
|
||||
:param prompt_messages: prompt messages
|
||||
:return: cleaned prompt messages
|
||||
"""
|
||||
checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"]
|
||||
|
||||
if model in checklist:
|
||||
# count how many user messages are there
|
||||
user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
|
||||
if user_message_count > 1:
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = "\n".join(
|
||||
[
|
||||
item.data
|
||||
if item.type == PromptMessageContentType.TEXT
|
||||
else "[IMAGE]"
|
||||
if item.type == PromptMessageContentType.IMAGE
|
||||
else ""
|
||||
for item in prompt_message.content
|
||||
]
|
||||
)
|
||||
|
||||
if model.startswith("o1"):
|
||||
system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
|
||||
if system_message_count > 0:
|
||||
new_prompt_messages = []
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, SystemPromptMessage):
|
||||
prompt_message = UserPromptMessage(
|
||||
content=prompt_message.content,
|
||||
name=prompt_message.name,
|
||||
)
|
||||
|
||||
new_prompt_messages.append(prompt_message)
|
||||
prompt_messages = new_prompt_messages
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: ChatCompletion,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
):
|
||||
assistant_message = response.choices[0].message
|
||||
assistant_message_tool_calls = assistant_message.tool_calls
|
||||
|
||||
# extract tool calls from response
|
||||
tool_calls = []
|
||||
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
|
||||
|
||||
# calculate num tokens
|
||||
if response.usage:
|
||||
# transform usage
|
||||
prompt_tokens = response.usage.prompt_tokens
|
||||
completion_tokens = response.usage.completion_tokens
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
||||
completion_tokens = self._num_tokens_from_messages(credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
model=response.model or model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Stream[ChatCompletionChunk],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
):
|
||||
index = 0
|
||||
full_assistant_content = ""
|
||||
real_model = model
|
||||
system_fingerprint = None
|
||||
completion = ""
|
||||
tool_calls = []
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0]
|
||||
# NOTE: For fix https://github.com/langgenius/dify/issues/5790
|
||||
if delta.delta is None:
|
||||
continue
|
||||
|
||||
# extract tool calls from response
|
||||
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
|
||||
|
||||
# Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
|
||||
if delta.finish_reason is None and not delta.delta.content:
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
|
||||
|
||||
full_assistant_content += delta.delta.content or ""
|
||||
|
||||
real_model = chunk.model
|
||||
system_fingerprint = chunk.system_fingerprint
|
||||
completion += delta.delta.content or ""
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=real_model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
),
|
||||
)
|
||||
|
||||
index += 1
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
||||
|
||||
full_assistant_prompt_message = AssistantPromptMessage(content=completion)
|
||||
completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=real_model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _update_tool_calls(
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall],
|
||||
tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]],
|
||||
) -> None:
|
||||
if tool_calls_response:
|
||||
for response_tool_call in tool_calls_response:
|
||||
if isinstance(response_tool_call, ChatCompletionMessageToolCall):
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.id, type=response_tool_call.type, function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
elif isinstance(response_tool_call, ChoiceDeltaToolCall):
|
||||
index = response_tool_call.index
|
||||
if index < len(tool_calls):
|
||||
tool_calls[index].id = response_tool_call.id or tool_calls[index].id
|
||||
tool_calls[index].type = response_tool_call.type or tool_calls[index].type
|
||||
if response_tool_call.function:
|
||||
tool_calls[index].function.name = (
|
||||
response_tool_call.function.name or tool_calls[index].function.name
|
||||
)
|
||||
tool_calls[index].function.arguments += response_tool_call.function.arguments or ""
|
||||
else:
|
||||
assert response_tool_call.id is not None
|
||||
assert response_tool_call.type is not None
|
||||
assert response_tool_call.function is not None
|
||||
assert response_tool_call.function.name is not None
|
||||
assert response_tool_call.function.arguments is not None
|
||||
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
|
||||
)
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.id, type=response_tool_call.type, function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
@staticmethod
|
||||
def _convert_prompt_message_to_dict(message: PromptMessage):
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
sub_messages = []
|
||||
assert message.content is not None
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
# message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {
|
||||
"role": "tool",
|
||||
"name": message.name,
|
||||
"content": message.content,
|
||||
"tool_call_id": message.tool_call_id,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
if message.name:
|
||||
message_dict["name"] = message.name
|
||||
|
||||
return message_dict
|
||||
|
||||
def _num_tokens_from_string(
|
||||
self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(credentials["base_model_name"])
|
||||
except KeyError:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
num_tokens = len(encoding.encode(text))
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
model = credentials["base_model_name"]
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken.get_encoding(model)
|
||||
|
||||
if model.startswith("gpt-35-turbo-0301"):
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
tokens_per_message = 4
|
||||
# if there's a name, the role is omitted
|
||||
tokens_per_name = -1
|
||||
elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
# TODO: The current token calculation method for the image type is not implemented,
|
||||
# which need to download the image and then get the resolution for calculation,
|
||||
# and will increase the request delay
|
||||
if isinstance(value, list):
|
||||
text = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item["type"] == "text":
|
||||
text += item["text"]
|
||||
|
||||
value = text
|
||||
|
||||
if key == "tool_calls":
|
||||
for tool_call in value:
|
||||
assert isinstance(tool_call, dict)
|
||||
for t_key, t_value in tool_call.items():
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
if t_key == "function":
|
||||
for f_key, f_value in t_value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
num_tokens += len(encoding.encode(t_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(str(value)))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
@staticmethod
|
||||
def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int:
|
||||
num_tokens = 0
|
||||
for tool in tools:
|
||||
num_tokens += len(encoding.encode("type"))
|
||||
num_tokens += len(encoding.encode("function"))
|
||||
|
||||
# calculate num tokens for function object
|
||||
num_tokens += len(encoding.encode("name"))
|
||||
num_tokens += len(encoding.encode(tool.name))
|
||||
num_tokens += len(encoding.encode("description"))
|
||||
num_tokens += len(encoding.encode(tool.description))
|
||||
parameters = tool.parameters
|
||||
num_tokens += len(encoding.encode("parameters"))
|
||||
if "title" in parameters:
|
||||
num_tokens += len(encoding.encode("title"))
|
||||
num_tokens += len(encoding.encode(parameters["title"]))
|
||||
num_tokens += len(encoding.encode("type"))
|
||||
num_tokens += len(encoding.encode(parameters["type"]))
|
||||
if "properties" in parameters:
|
||||
num_tokens += len(encoding.encode("properties"))
|
||||
for key, value in parameters["properties"].items():
|
||||
num_tokens += len(encoding.encode(key))
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
if field_key == "enum":
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(enum_field))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
num_tokens += len(encoding.encode(str(field_value)))
|
||||
if "required" in parameters:
|
||||
num_tokens += len(encoding.encode("required"))
|
||||
for required_field in parameters["required"]:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(required_field))
|
||||
|
||||
return num_tokens
|
||||
|
||||
@staticmethod
|
||||
def _get_ai_model_entity(base_model_name: str, model: str):
|
||||
for ai_model_entity in LLM_BASE_MODELS:
|
||||
if ai_model_entity.base_model_name == base_model_name:
|
||||
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
|
||||
ai_model_entity_copy.entity.model = model
|
||||
ai_model_entity_copy.entity.label.en_US = model
|
||||
ai_model_entity_copy.entity.label.zh_Hans = model
|
||||
return ai_model_entity_copy
|
||||
|
||||
def _get_base_model_name(self, credentials: dict) -> str:
|
||||
base_model_name = credentials.get("base_model_name")
|
||||
if not base_model_name:
|
||||
raise ValueError("Base Model Name is required")
|
||||
return base_model_name
|
|
@ -1,450 +0,0 @@
|
|||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import google.ai.generativelanguage as glm
|
||||
import google.generativeai as genai
|
||||
import requests
|
||||
from google.api_core import exceptions
|
||||
from google.generativeai.client import _ClientManager
|
||||
from google.generativeai.types import ContentType, GenerateContentResponse
|
||||
from google.generativeai.types.content_types import to_part
|
||||
from PIL import Image
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
|
||||
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
|
||||
if you are not sure about the structure.
|
||||
|
||||
<instructions>
|
||||
{{instructions}}
|
||||
</instructions>
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# invoke model
|
||||
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:md = genai.GenerativeModel(model)
|
||||
"""
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages)
|
||||
|
||||
return self._get_num_tokens_by_gpt2(prompt)
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Google model
|
||||
|
||||
:param messages: List of PromptMessage to combine.
|
||||
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
||||
"""
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
|
||||
text = "".join(self._convert_one_message_to_text(message) for message in messages)
|
||||
|
||||
return text.rstrip()
|
||||
|
||||
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
|
||||
"""
|
||||
Convert tool messages to glm tools
|
||||
|
||||
:param tools: tool messages
|
||||
:return: glm tools
|
||||
"""
|
||||
function_declarations = []
|
||||
for tool in tools:
|
||||
properties = {}
|
||||
for key, value in tool.parameters.get("properties", {}).items():
|
||||
properties[key] = {
|
||||
"type_": glm.Type.STRING,
|
||||
"description": value.get("description", ""),
|
||||
"enum": value.get("enum", []),
|
||||
}
|
||||
|
||||
if properties:
|
||||
parameters = glm.Schema(
|
||||
type=glm.Type.OBJECT,
|
||||
properties=properties,
|
||||
required=tool.parameters.get("required", []),
|
||||
)
|
||||
else:
|
||||
parameters = None
|
||||
|
||||
function_declaration = glm.FunctionDeclaration(
|
||||
name=tool.name,
|
||||
parameters=parameters,
|
||||
description=tool.description,
|
||||
)
|
||||
function_declarations.append(function_declaration)
|
||||
|
||||
return glm.Tool(function_declarations=function_declarations)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
|
||||
try:
|
||||
ping_message = SystemPromptMessage(content="ping")
|
||||
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
|
||||
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials kwargs
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
config_kwargs = model_parameters.copy()
|
||||
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
|
||||
|
||||
if stop:
|
||||
config_kwargs["stop_sequences"] = stop
|
||||
|
||||
google_model = genai.GenerativeModel(model_name=model)
|
||||
|
||||
history = []
|
||||
|
||||
# hack for gemini-pro-vision, which currently does not support multi-turn chat
|
||||
if model == "gemini-pro-vision":
|
||||
last_msg = prompt_messages[-1]
|
||||
content = self._format_message_to_glm_content(last_msg)
|
||||
history.append(content)
|
||||
else:
|
||||
for msg in prompt_messages: # makes message roles strictly alternating
|
||||
content = self._format_message_to_glm_content(msg)
|
||||
if history and history[-1]["role"] == content["role"]:
|
||||
history[-1]["parts"].extend(content["parts"])
|
||||
else:
|
||||
history.append(content)
|
||||
|
||||
# Create a new ClientManager with tenant's API key
|
||||
new_client_manager = _ClientManager()
|
||||
new_client_manager.configure(api_key=credentials["google_api_key"])
|
||||
new_custom_client = new_client_manager.make_client("generative")
|
||||
|
||||
google_model._client = new_custom_client
|
||||
|
||||
response = google_model.generate_content(
|
||||
contents=history,
|
||||
generation_config=genai.types.GenerationConfig(**config_kwargs),
|
||||
stream=stream,
|
||||
tools=self._convert_tools_to_glm_tool(tools) if tools else None,
|
||||
request_options={"timeout": 600},
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_response(
|
||||
self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response
|
||||
"""
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=response.text)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator result
|
||||
"""
|
||||
index = -1
|
||||
for chunk in response:
|
||||
for part in chunk.parts:
|
||||
assistant_prompt_message = AssistantPromptMessage(content="")
|
||||
|
||||
if part.text:
|
||||
assistant_prompt_message.content += part.text
|
||||
|
||||
if part.function_call:
|
||||
assistant_prompt_message.tool_calls = [
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=part.function_call.name,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=part.function_call.name,
|
||||
arguments=json.dumps(dict(part.function_call.args.items())),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
index += 1
|
||||
|
||||
if not response._done:
|
||||
# transform assistant message to prompt message
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
|
||||
)
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=str(chunk.candidates[0].finish_reason),
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
|
||||
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
|
||||
"""
|
||||
Convert a single message to a string.
|
||||
|
||||
:param message: PromptMessage to convert.
|
||||
:return: String representation of the message.
|
||||
"""
|
||||
human_prompt = "\n\nuser:"
|
||||
ai_prompt = "\n\nmodel:"
|
||||
|
||||
content = message.content
|
||||
if isinstance(content, list):
|
||||
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
|
||||
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message_text = f"{ai_prompt} {content}"
|
||||
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return message_text
|
||||
|
||||
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
|
||||
"""
|
||||
Format a single message into glm.Content for Google API
|
||||
|
||||
:param message: one PromptMessage
|
||||
:return: glm Content representation of message
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
glm_content = {"role": "user", "parts": []}
|
||||
if isinstance(message.content, str):
|
||||
glm_content["parts"].append(to_part(message.content))
|
||||
else:
|
||||
for c in message.content:
|
||||
if c.type == PromptMessageContentType.TEXT:
|
||||
glm_content["parts"].append(to_part(c.data))
|
||||
elif c.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, c)
|
||||
if message_content.data.startswith("data:"):
|
||||
metadata, base64_data = c.data.split(",", 1)
|
||||
mime_type = metadata.split(";", 1)[0].split(":")[1]
|
||||
else:
|
||||
# fetch image data from url
|
||||
try:
|
||||
image_content = requests.get(message_content.data).content
|
||||
with Image.open(io.BytesIO(image_content)) as img:
|
||||
mime_type = f"image/{img.format.lower()}"
|
||||
base64_data = base64.b64encode(image_content).decode("utf-8")
|
||||
except Exception as ex:
|
||||
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
||||
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
|
||||
glm_content["parts"].append(blob)
|
||||
|
||||
return glm_content
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
glm_content = {"role": "model", "parts": []}
|
||||
if message.content:
|
||||
glm_content["parts"].append(to_part(message.content))
|
||||
if message.tool_calls:
|
||||
glm_content["parts"].append(
|
||||
to_part(
|
||||
glm.FunctionCall(
|
||||
name=message.tool_calls[0].function.name,
|
||||
args=json.loads(message.tool_calls[0].function.arguments),
|
||||
)
|
||||
)
|
||||
)
|
||||
return glm_content
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
return {"role": "user", "parts": [to_part(message.content)]}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
return {
|
||||
"role": "function",
|
||||
"parts": [
|
||||
glm.Part(
|
||||
function_response=glm.FunctionResponse(
|
||||
name=message.name, response={"response": message.content}
|
||||
)
|
||||
)
|
||||
],
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke emd = genai.GenerativeModel(model) error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [exceptions.RetryError],
|
||||
InvokeServerUnavailableError: [
|
||||
exceptions.ServiceUnavailable,
|
||||
exceptions.InternalServerError,
|
||||
exceptions.BadGateway,
|
||||
exceptions.GatewayTimeout,
|
||||
exceptions.DeadlineExceeded,
|
||||
],
|
||||
InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests],
|
||||
InvokeAuthorizationError: [
|
||||
exceptions.Unauthenticated,
|
||||
exceptions.PermissionDenied,
|
||||
exceptions.Unauthenticated,
|
||||
exceptions.Forbidden,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
exceptions.BadRequest,
|
||||
exceptions.InvalidArgument,
|
||||
exceptions.FailedPrecondition,
|
||||
exceptions.OutOfRange,
|
||||
exceptions.NotFound,
|
||||
exceptions.MethodNotAllowed,
|
||||
exceptions.Conflict,
|
||||
exceptions.AlreadyExists,
|
||||
exceptions.Aborted,
|
||||
exceptions.LengthRequired,
|
||||
exceptions.PreconditionFailed,
|
||||
exceptions.RequestRangeNotSatisfiable,
|
||||
exceptions.Cancelled,
|
||||
],
|
||||
}
|
|
@ -1,330 +0,0 @@
|
|||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import requests
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
|
||||
|
||||
class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
self._add_function_call(model, credentials)
|
||||
user = user[:32] if user else None
|
||||
# {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
|
||||
if "response_format" in model_parameters:
|
||||
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
return AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model, zh_Hans=model),
|
||||
model_type=ModelType.LLM,
|
||||
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
|
||||
if credentials.get("function_calling_type") == "tool_call"
|
||||
else [],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
|
||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
use_template="temperature",
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
ParameterRule(
|
||||
name="max_tokens",
|
||||
use_template="max_tokens",
|
||||
default=512,
|
||||
min=1,
|
||||
max=int(credentials.get("max_tokens", 4096)),
|
||||
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
|
||||
type=ParameterType.INT,
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_p",
|
||||
use_template="top_p",
|
||||
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def _add_custom_parameters(self, credentials: dict) -> None:
|
||||
credentials["mode"] = "chat"
|
||||
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
|
||||
credentials["endpoint_url"] = "https://api.moonshot.cn/v1"
|
||||
|
||||
def _add_function_call(self, model: str, credentials: dict) -> None:
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection(
|
||||
model_schema.features or []
|
||||
):
|
||||
credentials["function_calling_type"] = "tool_call"
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict for OpenAI API format
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(PromptMessageContent, message_content)
|
||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
message_dict["tool_calls"] = []
|
||||
for function_call in message.tool_calls:
|
||||
message_dict["tool_calls"].append(
|
||||
{
|
||||
"id": function_call.id,
|
||||
"type": function_call.type,
|
||||
"function": {
|
||||
"name": function_call.function.name,
|
||||
"arguments": function_call.function.arguments,
|
||||
},
|
||||
}
|
||||
)
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
if message.name:
|
||||
message_dict["name"] = message.name
|
||||
|
||||
return message_dict
|
||||
|
||||
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
|
||||
"""
|
||||
Extract tool calls from response
|
||||
|
||||
:param response_tool_calls: response tool calls
|
||||
:return: list of tool calls
|
||||
"""
|
||||
tool_calls = []
|
||||
if response_tool_calls:
|
||||
for response_tool_call in response_tool_calls:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call["function"]["name"]
|
||||
if response_tool_call.get("function", {}).get("name")
|
||||
else "",
|
||||
arguments=response_tool_call["function"]["arguments"]
|
||||
if response_tool_call.get("function", {}).get("arguments")
|
||||
else "",
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call["id"] if response_tool_call.get("id") else "",
|
||||
type=response_tool_call["type"] if response_tool_call.get("type") else "",
|
||||
function=function,
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param response: streamed response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
full_assistant_content = ""
|
||||
chunk_index = 0
|
||||
|
||||
def create_final_llm_result_chunk(
|
||||
index: int, message: AssistantPromptMessage, finish_reason: str
|
||||
) -> LLMResultChunk:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
||||
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
return LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
|
||||
)
|
||||
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
finish_reason = "Unknown"
|
||||
|
||||
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
||||
def get_tool_call(tool_name: str):
|
||||
if not tool_name:
|
||||
return tools_calls[-1]
|
||||
|
||||
tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id="",
|
||||
type="",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
|
||||
)
|
||||
tools_calls.append(tool_call)
|
||||
|
||||
return tool_call
|
||||
|
||||
for new_tool_call in new_tool_calls:
|
||||
# get tool call
|
||||
tool_call = get_tool_call(new_tool_call.function.name)
|
||||
# update tool call
|
||||
if new_tool_call.id:
|
||||
tool_call.id = new_tool_call.id
|
||||
if new_tool_call.type:
|
||||
tool_call.type = new_tool_call.type
|
||||
if new_tool_call.function.name:
|
||||
tool_call.function.name = new_tool_call.function.name
|
||||
if new_tool_call.function.arguments:
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
|
||||
if chunk:
|
||||
# ignore sse comments
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
|
||||
chunk_json = None
|
||||
try:
|
||||
chunk_json = json.loads(decoded_chunk)
|
||||
# stream ended
|
||||
except json.JSONDecodeError as e:
|
||||
yield create_final_llm_result_chunk(
|
||||
index=chunk_index + 1,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason="Non-JSON encountered.",
|
||||
)
|
||||
break
|
||||
if not chunk_json or len(chunk_json["choices"]) == 0:
|
||||
continue
|
||||
|
||||
choice = chunk_json["choices"][0]
|
||||
finish_reason = chunk_json["choices"][0].get("finish_reason")
|
||||
chunk_index += 1
|
||||
|
||||
if "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
delta_content = delta.get("content")
|
||||
|
||||
assistant_message_tool_calls = delta.get("tool_calls", None)
|
||||
# assistant_message_function_call = delta.delta.function_call
|
||||
|
||||
# extract tool calls from response
|
||||
if assistant_message_tool_calls:
|
||||
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
increase_tool_call(tool_calls)
|
||||
|
||||
if delta_content is None or delta_content == "":
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else []
|
||||
)
|
||||
|
||||
full_assistant_content += delta_content
|
||||
elif "text" in choice:
|
||||
choice_text = choice.get("text", "")
|
||||
if choice_text == "":
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
|
||||
full_assistant_content += choice_text
|
||||
else:
|
||||
continue
|
||||
|
||||
# check payload indicator for completion
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=assistant_prompt_message,
|
||||
),
|
||||
)
|
||||
|
||||
chunk_index += 1
|
||||
|
||||
if tools_calls:
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
|
||||
),
|
||||
)
|
||||
|
||||
yield create_final_llm_result_chunk(
|
||||
index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason
|
||||
)
|
|
@ -1,847 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from decimal import Decimal
|
||||
from typing import Optional, Union, cast
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
PromptMessageFunction,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
DefaultParameterName,
|
||||
FetchFrom,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
PriceConfig,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
from core.model_runtime.utils import helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
"""
|
||||
Model class for OpenAI large language model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
# text completion model
|
||||
return self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model:
|
||||
:param credentials:
|
||||
:param prompt_messages:
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials using requests to ensure compatibility with all providers following
|
||||
OpenAI's API standard.
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
api_key = credentials.get("api_key")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
endpoint_url = credentials["endpoint_url"]
|
||||
if not endpoint_url.endswith("/"):
|
||||
endpoint_url += "/"
|
||||
|
||||
# prepare the payload for a simple ping to the model
|
||||
data = {"model": model, "max_tokens": 5}
|
||||
|
||||
completion_type = LLMMode.value_of(credentials["mode"])
|
||||
|
||||
if completion_type is LLMMode.CHAT:
|
||||
data["messages"] = [
|
||||
{"role": "user", "content": "ping"},
|
||||
]
|
||||
endpoint_url = urljoin(endpoint_url, "chat/completions")
|
||||
elif completion_type is LLMMode.COMPLETION:
|
||||
data["prompt"] = "ping"
|
||||
endpoint_url = urljoin(endpoint_url, "completions")
|
||||
else:
|
||||
raise ValueError("Unsupported completion type for model configuration.")
|
||||
|
||||
# send a post request to validate the credentials
|
||||
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise CredentialsValidateFailedError(
|
||||
f"Credentials validation failed with status code {response.status_code}"
|
||||
)
|
||||
|
||||
try:
|
||||
json_result = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error")
|
||||
|
||||
if completion_type is LLMMode.CHAT and json_result.get("object", "") == "":
|
||||
json_result["object"] = "chat.completion"
|
||||
elif completion_type is LLMMode.COMPLETION and json_result.get("object", "") == "":
|
||||
json_result["object"] = "text_completion"
|
||||
|
||||
if completion_type is LLMMode.CHAT and (
|
||||
"object" not in json_result or json_result["object"] != "chat.completion"
|
||||
):
|
||||
raise CredentialsValidateFailedError(
|
||||
"Credentials validation failed: invalid response object, must be 'chat.completion'"
|
||||
)
|
||||
elif completion_type is LLMMode.COMPLETION and (
|
||||
"object" not in json_result or json_result["object"] != "text_completion"
|
||||
):
|
||||
raise CredentialsValidateFailedError(
|
||||
"Credentials validation failed: invalid response object, must be 'text_completion'"
|
||||
)
|
||||
except CredentialsValidateFailedError:
|
||||
raise
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}")
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
features = []
|
||||
|
||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
||||
if function_calling_type == "function_call":
|
||||
features.append(ModelFeature.TOOL_CALL)
|
||||
elif function_calling_type == "tool_call":
|
||||
features.append(ModelFeature.MULTI_TOOL_CALL)
|
||||
|
||||
stream_function_calling = credentials.get("stream_function_calling", "supported")
|
||||
if stream_function_calling == "supported":
|
||||
features.append(ModelFeature.STREAM_TOOL_CALL)
|
||||
|
||||
vision_support = credentials.get("vision_support", "not_support")
|
||||
if vision_support == "support":
|
||||
features.append(ModelFeature.VISION)
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
features=features,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096")),
|
||||
ModelPropertyKey.MODE: credentials.get("mode"),
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TEMPERATURE.value,
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
help=I18nObject(
|
||||
en_US="Kernel sampling threshold. Used to determine the randomness of the results."
|
||||
"The higher the value, the stronger the randomness."
|
||||
"The higher the possibility of getting different answers to the same question.",
|
||||
zh_Hans="核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。",
|
||||
),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get("temperature", 0.7)),
|
||||
min=0,
|
||||
max=2,
|
||||
precision=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TOP_P.value,
|
||||
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
||||
help=I18nObject(
|
||||
en_US="The probability threshold of the nucleus sampling method during the generation process."
|
||||
"The larger the value is, the higher the randomness of generation will be."
|
||||
"The smaller the value is, the higher the certainty of generation will be.",
|
||||
zh_Hans="生成过程中核采样方法概率阈值。取值越大,生成的随机性越高;取值越小,生成的确定性越高。",
|
||||
),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get("top_p", 1)),
|
||||
min=0,
|
||||
max=1,
|
||||
precision=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.FREQUENCY_PENALTY.value,
|
||||
label=I18nObject(en_US="Frequency Penalty", zh_Hans="频率惩罚"),
|
||||
help=I18nObject(
|
||||
en_US="For controlling the repetition rate of words used by the model."
|
||||
"Increasing this can reduce the repetition of the same words in the model's output.",
|
||||
zh_Hans="用于控制模型已使用字词的重复率。 提高此项可以降低模型在输出中重复相同字词的重复度。",
|
||||
),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get("frequency_penalty", 0)),
|
||||
min=-2,
|
||||
max=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.PRESENCE_PENALTY.value,
|
||||
label=I18nObject(en_US="Presence Penalty", zh_Hans="存在惩罚"),
|
||||
help=I18nObject(
|
||||
en_US="Used to control the repetition rate when generating models."
|
||||
"Increasing this can reduce the repetition rate of model generation.",
|
||||
zh_Hans="用于控制模型生成时的重复度。提高此项可以降低模型生成的重复度。",
|
||||
),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get("presence_penalty", 0)),
|
||||
min=-2,
|
||||
max=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.MAX_TOKENS.value,
|
||||
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
|
||||
help=I18nObject(
|
||||
en_US="Maximum length of tokens for the model response.", zh_Hans="模型回答的tokens的最大长度。"
|
||||
),
|
||||
type=ParameterType.INT,
|
||||
default=512,
|
||||
min=1,
|
||||
max=int(credentials.get("max_tokens_to_sample", 4096)),
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(credentials.get("input_price", 0)),
|
||||
output=Decimal(credentials.get("output_price", 0)),
|
||||
unit=Decimal(credentials.get("unit", 0)),
|
||||
currency=credentials.get("currency", "USD"),
|
||||
),
|
||||
)
|
||||
|
||||
if credentials["mode"] == "chat":
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
|
||||
elif credentials["mode"] == "completion":
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
|
||||
|
||||
return entity
|
||||
|
||||
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers
|
||||
# following OpenAI's API standard.
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke llm completion model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept-Charset": "utf-8",
|
||||
}
|
||||
extra_headers = credentials.get("extra_headers")
|
||||
if extra_headers is not None:
|
||||
headers = {
|
||||
**headers,
|
||||
**extra_headers,
|
||||
}
|
||||
|
||||
api_key = credentials.get("api_key")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
endpoint_url = credentials["endpoint_url"]
|
||||
if not endpoint_url.endswith("/"):
|
||||
endpoint_url += "/"
|
||||
|
||||
data = {"model": model, "stream": stream, **model_parameters}
|
||||
|
||||
completion_type = LLMMode.value_of(credentials["mode"])
|
||||
|
||||
if completion_type is LLMMode.CHAT:
|
||||
endpoint_url = urljoin(endpoint_url, "chat/completions")
|
||||
data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
|
||||
elif completion_type is LLMMode.COMPLETION:
|
||||
endpoint_url = urljoin(endpoint_url, "completions")
|
||||
data["prompt"] = prompt_messages[0].content
|
||||
else:
|
||||
raise ValueError("Unsupported completion type for model configuration.")
|
||||
|
||||
# annotate tools with names, descriptions, etc.
|
||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
||||
formatted_tools = []
|
||||
if tools:
|
||||
if function_calling_type == "function_call":
|
||||
data["functions"] = [
|
||||
{"name": tool.name, "description": tool.description, "parameters": tool.parameters}
|
||||
for tool in tools
|
||||
]
|
||||
elif function_calling_type == "tool_call":
|
||||
data["tool_choice"] = "auto"
|
||||
|
||||
for tool in tools:
|
||||
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
|
||||
|
||||
data["tools"] = formatted_tools
|
||||
|
||||
if stop:
|
||||
data["stop"] = stop
|
||||
|
||||
if user:
|
||||
data["user"] = user
|
||||
|
||||
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream)
|
||||
|
||||
if response.encoding is None or response.encoding == "ISO-8859-1":
|
||||
response.encoding = "utf-8"
|
||||
|
||||
if response.status_code != 200:
|
||||
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param response: streamed response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
full_assistant_content = ""
|
||||
chunk_index = 0
|
||||
|
||||
def create_final_llm_result_chunk(
|
||||
id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
|
||||
) -> LLMResultChunk:
|
||||
# calculate num tokens
|
||||
prompt_tokens = usage and usage.get("prompt_tokens")
|
||||
if prompt_tokens is None:
|
||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
||||
completion_tokens = usage and usage.get("completion_tokens")
|
||||
if completion_tokens is None:
|
||||
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
return LLMResultChunk(
|
||||
id=id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
|
||||
)
|
||||
|
||||
# delimiter for stream response, need unicode_escape
|
||||
import codecs
|
||||
|
||||
delimiter = credentials.get("stream_mode_delimiter", "\n\n")
|
||||
delimiter = codecs.decode(delimiter, "unicode_escape")
|
||||
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
|
||||
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
||||
def get_tool_call(tool_call_id: str):
|
||||
if not tool_call_id:
|
||||
return tools_calls[-1]
|
||||
|
||||
tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
||||
)
|
||||
tools_calls.append(tool_call)
|
||||
|
||||
return tool_call
|
||||
|
||||
for new_tool_call in new_tool_calls:
|
||||
# get tool call
|
||||
tool_call = get_tool_call(new_tool_call.function.name)
|
||||
# update tool call
|
||||
if new_tool_call.id:
|
||||
tool_call.id = new_tool_call.id
|
||||
if new_tool_call.type:
|
||||
tool_call.type = new_tool_call.type
|
||||
if new_tool_call.function.name:
|
||||
tool_call.function.name = new_tool_call.function.name
|
||||
if new_tool_call.function.arguments:
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
|
||||
finish_reason = None # The default value of finish_reason is None
|
||||
message_id, usage = None, None
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
||||
chunk = chunk.strip()
|
||||
if chunk:
|
||||
# ignore sse comments
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
|
||||
if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]"
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk_json: dict = json.loads(decoded_chunk)
|
||||
# stream ended
|
||||
except json.JSONDecodeError as e:
|
||||
yield create_final_llm_result_chunk(
|
||||
id=message_id,
|
||||
index=chunk_index + 1,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason="Non-JSON encountered.",
|
||||
usage=usage,
|
||||
)
|
||||
break
|
||||
if chunk_json:
|
||||
if u := chunk_json.get("usage"):
|
||||
usage = u
|
||||
if not chunk_json or len(chunk_json["choices"]) == 0:
|
||||
continue
|
||||
|
||||
choice = chunk_json["choices"][0]
|
||||
finish_reason = chunk_json["choices"][0].get("finish_reason")
|
||||
message_id = chunk_json.get("id")
|
||||
chunk_index += 1
|
||||
|
||||
if "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
delta_content = delta.get("content")
|
||||
|
||||
assistant_message_tool_calls = None
|
||||
|
||||
if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call":
|
||||
assistant_message_tool_calls = delta.get("tool_calls", None)
|
||||
elif (
|
||||
"function_call" in delta
|
||||
and credentials.get("function_calling_type", "no_call") == "function_call"
|
||||
):
|
||||
assistant_message_tool_calls = [
|
||||
{"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
|
||||
]
|
||||
|
||||
# assistant_message_function_call = delta.delta.function_call
|
||||
|
||||
# extract tool calls from response
|
||||
if assistant_message_tool_calls:
|
||||
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
increase_tool_call(tool_calls)
|
||||
|
||||
if delta_content is None or delta_content == "":
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta_content,
|
||||
)
|
||||
|
||||
# reset tool calls
|
||||
tool_calls = []
|
||||
full_assistant_content += delta_content
|
||||
elif "text" in choice:
|
||||
choice_text = choice.get("text", "")
|
||||
if choice_text == "":
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
|
||||
full_assistant_content += choice_text
|
||||
else:
|
||||
continue
|
||||
|
||||
yield LLMResultChunk(
|
||||
id=message_id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=assistant_prompt_message,
|
||||
),
|
||||
)
|
||||
|
||||
chunk_index += 1
|
||||
|
||||
if tools_calls:
|
||||
yield LLMResultChunk(
|
||||
id=message_id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
|
||||
),
|
||||
)
|
||||
|
||||
yield create_final_llm_result_chunk(
|
||||
id=message_id,
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _handle_generate_response(
|
||||
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
response_json: dict = response.json()
|
||||
|
||||
completion_type = LLMMode.value_of(credentials["mode"])
|
||||
|
||||
output = response_json["choices"][0]
|
||||
message_id = response_json.get("id")
|
||||
|
||||
response_content = ""
|
||||
tool_calls = None
|
||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
||||
if completion_type is LLMMode.CHAT:
|
||||
response_content = output.get("message", {})["content"]
|
||||
if function_calling_type == "tool_call":
|
||||
tool_calls = output.get("message", {}).get("tool_calls")
|
||||
elif function_calling_type == "function_call":
|
||||
tool_calls = output.get("message", {}).get("function_call")
|
||||
|
||||
elif completion_type is LLMMode.COMPLETION:
|
||||
response_content = output["text"]
|
||||
|
||||
assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])
|
||||
|
||||
if tool_calls:
|
||||
if function_calling_type == "tool_call":
|
||||
assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
|
||||
elif function_calling_type == "function_call":
|
||||
assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]
|
||||
|
||||
usage = response_json.get("usage")
|
||||
if usage:
|
||||
# transform usage
|
||||
prompt_tokens = usage["prompt_tokens"]
|
||||
completion_tokens = usage["completion_tokens"]
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
||||
completion_tokens = self._num_tokens_from_string(model, assistant_message.content)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
id=message_id,
|
||||
model=response_json["model"],
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_message,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict for OpenAI API format
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(PromptMessageContent, message_content)
|
||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
||||
if function_calling_type == "tool_call":
|
||||
message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls]
|
||||
elif function_calling_type == "function_call":
|
||||
function_call = message.tool_calls[0]
|
||||
message_dict["function_call"] = {
|
||||
"name": function_call.function.name,
|
||||
"arguments": function_call.function.arguments,
|
||||
}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
||||
if function_calling_type == "tool_call":
|
||||
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
|
||||
elif function_calling_type == "function_call":
|
||||
message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
if message.name and message_dict.get("role", "") != "tool":
|
||||
message_dict["name"] = message.name
|
||||
|
||||
return message_dict
|
||||
|
||||
def _num_tokens_from_string(
|
||||
self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Approximate num tokens for model with gpt2 tokenizer.
|
||||
|
||||
:param model: model name
|
||||
:param text: prompt text
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
"""
|
||||
if isinstance(text, str):
|
||||
full_text = text
|
||||
else:
|
||||
full_text = ""
|
||||
for message_content in text:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(PromptMessageContent, message_content)
|
||||
full_text += message_content.data
|
||||
|
||||
num_tokens = self._get_num_tokens_by_gpt2(full_text)
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
credentials: Optional[dict] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Approximate num tokens with GPT2 tokenizer.
|
||||
"""
|
||||
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
|
||||
num_tokens = 0
|
||||
messages_dict = [self._convert_prompt_message_to_dict(m, credentials) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
# TODO: The current token calculation method for the image type is not implemented,
|
||||
# which need to download the image and then get the resolution for calculation,
|
||||
# and will increase the request delay
|
||||
if isinstance(value, list):
|
||||
text = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item["type"] == "text":
|
||||
text += item["text"]
|
||||
|
||||
value = text
|
||||
|
||||
if key == "tool_calls":
|
||||
for tool_call in value:
|
||||
for t_key, t_value in tool_call.items():
|
||||
num_tokens += self._get_num_tokens_by_gpt2(t_key)
|
||||
if t_key == "function":
|
||||
for f_key, f_value in t_value.items():
|
||||
num_tokens += self._get_num_tokens_by_gpt2(f_key)
|
||||
num_tokens += self._get_num_tokens_by_gpt2(f_value)
|
||||
else:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(t_key)
|
||||
num_tokens += self._get_num_tokens_by_gpt2(t_value)
|
||||
else:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(str(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
|
||||
"""
|
||||
Calculate num tokens for tool calling with tiktoken package.
|
||||
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
"""
|
||||
num_tokens = 0
|
||||
for tool in tools:
|
||||
num_tokens += self._get_num_tokens_by_gpt2("type")
|
||||
num_tokens += self._get_num_tokens_by_gpt2("function")
|
||||
num_tokens += self._get_num_tokens_by_gpt2("function")
|
||||
|
||||
# calculate num tokens for function object
|
||||
num_tokens += self._get_num_tokens_by_gpt2("name")
|
||||
num_tokens += self._get_num_tokens_by_gpt2(tool.name)
|
||||
num_tokens += self._get_num_tokens_by_gpt2("description")
|
||||
num_tokens += self._get_num_tokens_by_gpt2(tool.description)
|
||||
parameters = tool.parameters
|
||||
num_tokens += self._get_num_tokens_by_gpt2("parameters")
|
||||
if "title" in parameters:
|
||||
num_tokens += self._get_num_tokens_by_gpt2("title")
|
||||
num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title"))
|
||||
num_tokens += self._get_num_tokens_by_gpt2("type")
|
||||
num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type"))
|
||||
if "properties" in parameters:
|
||||
num_tokens += self._get_num_tokens_by_gpt2("properties")
|
||||
for key, value in parameters.get("properties").items():
|
||||
num_tokens += self._get_num_tokens_by_gpt2(key)
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += self._get_num_tokens_by_gpt2(field_key)
|
||||
if field_key == "enum":
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += self._get_num_tokens_by_gpt2(enum_field)
|
||||
else:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(field_key)
|
||||
num_tokens += self._get_num_tokens_by_gpt2(str(field_value))
|
||||
if "required" in parameters:
|
||||
num_tokens += self._get_num_tokens_by_gpt2("required")
|
||||
for required_field in parameters["required"]:
|
||||
num_tokens += 3
|
||||
num_tokens += self._get_num_tokens_by_gpt2(required_field)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
|
||||
"""
|
||||
Extract tool calls from response
|
||||
|
||||
:param response_tool_calls: response tool calls
|
||||
:return: list of tool calls
|
||||
"""
|
||||
tool_calls = []
|
||||
if response_tool_calls:
|
||||
for response_tool_call in response_tool_calls:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.get("function", {}).get("name", ""),
|
||||
arguments=response_tool_call.get("function", {}).get("arguments", ""),
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.get("id", ""), type=response_tool_call.get("type", ""), function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _extract_response_function_call(self, response_function_call) -> AssistantPromptMessage.ToolCall:
|
||||
"""
|
||||
Extract function call from response
|
||||
|
||||
:param response_function_call: response function call
|
||||
:return: tool call
|
||||
"""
|
||||
tool_call = None
|
||||
if response_function_call:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_function_call.get("name", ""), arguments=response_function_call.get("arguments", "")
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_function_call.get("id", ""), type="function", function=function
|
||||
)
|
||||
|
||||
return tool_call
|
Binary file not shown.
After Width: | Height: | Size: 11 KiB |
|
@ -0,0 +1,3 @@
|
|||
<svg width="1200" height="925" viewBox="0 0 1200 925" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M780.152 250.999L907.882 462.174C907.882 462.174 880.925 510.854 867.43 535.21C834.845 594.039 764.171 612.49 710.442 508.333L420.376 0H0L459.926 803.307C552.303 964.663 787.366 964.663 879.743 803.307C989.874 610.952 1089.87 441.97 1200 249.646L1052.28 0H639.519L780.152 250.999Z" fill="#3366FF"/>
|
||||
</svg>
|
After Width: | Height: | Size: 417 B |
83
api/core/model_runtime/model_providers/vessl_ai/llm/llm.py
Normal file
83
api/core/model_runtime/model_providers/vessl_ai/llm/llm.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
from decimal import Decimal
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
DefaultParameterName,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
PriceConfig,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
|
||||
|
||||
class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
features = []
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
features=features,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: credentials.get("mode"),
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TEMPERATURE.value,
|
||||
label=I18nObject(en_US="Temperature"),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get("temperature", 0.7)),
|
||||
min=0,
|
||||
max=2,
|
||||
precision=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TOP_P.value,
|
||||
label=I18nObject(en_US="Top P"),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get("top_p", 1)),
|
||||
min=0,
|
||||
max=1,
|
||||
precision=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TOP_K.value,
|
||||
label=I18nObject(en_US="Top K"),
|
||||
type=ParameterType.INT,
|
||||
default=int(credentials.get("top_k", 50)),
|
||||
min=-2147483647,
|
||||
max=2147483647,
|
||||
precision=0,
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.MAX_TOKENS.value,
|
||||
label=I18nObject(en_US="Max Tokens"),
|
||||
type=ParameterType.INT,
|
||||
default=512,
|
||||
min=1,
|
||||
max=int(credentials.get("max_tokens_to_sample", 4096)),
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(credentials.get("input_price", 0)),
|
||||
output=Decimal(credentials.get("output_price", 0)),
|
||||
unit=Decimal(credentials.get("unit", 0)),
|
||||
currency=credentials.get("currency", "USD"),
|
||||
),
|
||||
)
|
||||
|
||||
if credentials["mode"] == "chat":
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
|
||||
elif credentials["mode"] == "completion":
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
|
||||
|
||||
return entity
|
10
api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py
Normal file
10
api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
import logging
|
||||
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VesslAIProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
pass
|
|
@ -0,0 +1,56 @@
|
|||
provider: vessl_ai
|
||||
label:
|
||||
en_US: vessl_ai
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.png
|
||||
background: "#F1EFED"
|
||||
help:
|
||||
title:
|
||||
en_US: How to deploy VESSL AI LLM Model Endpoint
|
||||
url:
|
||||
en_US: https://docs.vessl.ai/guides/get-started/llama3-deployment
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model Name
|
||||
placeholder:
|
||||
en_US: Enter your model name
|
||||
credential_form_schemas:
|
||||
- variable: endpoint_url
|
||||
label:
|
||||
en_US: endpoint url
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
en_US: Enter the url of your endpoint url
|
||||
- variable: api_key
|
||||
required: true
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
placeholder:
|
||||
en_US: Enter your VESSL AI secret key
|
||||
- variable: mode
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
en_US: Completion mode
|
||||
type: select
|
||||
required: false
|
||||
default: chat
|
||||
placeholder:
|
||||
en_US: Select completion mode
|
||||
options:
|
||||
- value: completion
|
||||
label:
|
||||
en_US: Completion
|
||||
- value: chat
|
||||
label:
|
||||
en_US: Chat
|
|
@ -413,7 +413,7 @@ class PluginModelManager(BasePluginManager):
|
|||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/model/voices",
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
|
||||
type=PluginVoicesResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
|
@ -434,8 +434,10 @@ class PluginModelManager(BasePluginManager):
|
|||
)
|
||||
|
||||
for resp in response:
|
||||
voices = []
|
||||
for voice in resp.voices:
|
||||
return [{"name": voice.name, "value": voice.value}]
|
||||
voices.append({"name": voice.name, "value": voice.value})
|
||||
return voices
|
||||
|
||||
return []
|
||||
|
||||
|
|
|
@ -34,6 +34,8 @@ class RetrievalService:
|
|||
reranking_mode: Optional[str] = "reranking_model",
|
||||
weights: Optional[dict] = None,
|
||||
):
|
||||
if not query:
|
||||
return []
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return []
|
||||
|
|
|
@ -3,11 +3,13 @@ import time
|
|||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pymochow import MochowClient
|
||||
from pymochow.auth.bce_credentials import BceCredentials
|
||||
from pymochow.configuration import Configuration
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState
|
||||
from pymochow.exception import ServerError
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState
|
||||
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
|
||||
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
|
||||
|
||||
|
@ -116,6 +118,7 @@ class BaiduVector(BaseVector):
|
|||
self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
|
@ -149,7 +152,13 @@ class BaiduVector(BaseVector):
|
|||
return docs
|
||||
|
||||
def delete(self) -> None:
|
||||
try:
|
||||
self._db.drop_table(table_name=self._collection_name)
|
||||
except ServerError as e:
|
||||
if e.code == ServerErrCode.TABLE_NOT_EXIST:
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
def _init_client(self, config) -> MochowClient:
|
||||
config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
|
||||
|
@ -166,7 +175,14 @@ class BaiduVector(BaseVector):
|
|||
if exists:
|
||||
return self._client.database(self._client_config.database)
|
||||
else:
|
||||
return self._client.create_database(database_name=self._client_config.database)
|
||||
try:
|
||||
self._client.create_database(database_name=self._client_config.database)
|
||||
except ServerError as e:
|
||||
if e.code == ServerErrCode.DB_ALREADY_EXIST:
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
return
|
||||
|
||||
def _table_existed(self) -> bool:
|
||||
tables = self._db.list_table()
|
||||
|
@ -175,7 +191,7 @@ class BaiduVector(BaseVector):
|
|||
def _create_table(self, dimension: int) -> None:
|
||||
# Try to grab distributed lock and create table
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
with redis_client.lock(lock_name, timeout=60):
|
||||
table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(table_exist_cache_key):
|
||||
return
|
||||
|
@ -238,14 +254,13 @@ class BaiduVector(BaseVector):
|
|||
description="Table for Dify",
|
||||
)
|
||||
|
||||
redis_client.set(table_exist_cache_key, 1, ex=3600)
|
||||
|
||||
# Wait for table created
|
||||
while True:
|
||||
time.sleep(1)
|
||||
table = self._db.describe_table(self._collection_name)
|
||||
if table.state == TableState.NORMAL:
|
||||
break
|
||||
redis_client.set(table_exist_cache_key, 1, ex=3600)
|
||||
|
||||
|
||||
class BaiduVectorFactory(AbstractVectorFactory):
|
||||
|
|
|
@ -37,7 +37,7 @@ class TidbService:
|
|||
}
|
||||
|
||||
spending_limit = {
|
||||
"monthly": 100,
|
||||
"monthly": dify_config.TIDB_SPEND_LIMIT,
|
||||
}
|
||||
password = str(uuid.uuid4()).replace("-", "")[:16]
|
||||
display_name = str(uuid.uuid4()).replace("-", "")[:16]
|
||||
|
|
|
@ -27,16 +27,15 @@ class RerankModelRunner(BaseRerankRunner):
|
|||
:return:
|
||||
"""
|
||||
docs = []
|
||||
doc_id = []
|
||||
doc_id = set()
|
||||
unique_documents = []
|
||||
dify_documents = [item for item in documents if item.provider == "dify"]
|
||||
external_documents = [item for item in documents if item.provider == "external"]
|
||||
for document in dify_documents:
|
||||
if document.metadata["doc_id"] not in doc_id:
|
||||
doc_id.append(document.metadata["doc_id"])
|
||||
for document in documents:
|
||||
if document.provider == "dify" and document.metadata["doc_id"] not in doc_id:
|
||||
doc_id.add(document.metadata["doc_id"])
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
for document in external_documents:
|
||||
elif document.provider == "external":
|
||||
if document not in unique_documents:
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
|
||||
|
|
|
@ -116,10 +116,8 @@ class ToolInvokeMessage(BaseModel):
|
|||
|
||||
class VariableMessage(BaseModel):
|
||||
variable_name: str = Field(..., description="The name of the variable")
|
||||
variable_value: str = Field(...,
|
||||
description="The value of the variable")
|
||||
stream: bool = Field(
|
||||
default=False, description="Whether the variable is streamed")
|
||||
variable_value: str = Field(..., description="The value of the variable")
|
||||
stream: bool = Field(default=False, description="Whether the variable is streamed")
|
||||
|
||||
@field_validator("variable_value", mode="before")
|
||||
@classmethod
|
||||
|
@ -133,8 +131,7 @@ class ToolInvokeMessage(BaseModel):
|
|||
# if stream is true, the value must be a string
|
||||
if values.get("stream"):
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(
|
||||
"When 'stream' is True, 'variable_value' must be a string.")
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
|
||||
return value
|
||||
|
||||
|
@ -271,8 +268,7 @@ class ToolParameter(BaseModel):
|
|||
return str(value)
|
||||
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"The tool parameter value {value} is not in correct type.")
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {self.as_normal_type()}.")
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
|
@ -280,17 +276,12 @@ class ToolParameter(BaseModel):
|
|||
LLM = "llm" # will be set by LLM
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(...,
|
||||
description="The label presented to the user")
|
||||
human_description: Optional[I18nObject] = Field(
|
||||
default=None, description="The description presented to the user")
|
||||
placeholder: Optional[I18nObject] = Field(
|
||||
default=None, description="The placeholder presented to the user")
|
||||
type: ToolParameterType = Field(...,
|
||||
description="The type of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
||||
placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user")
|
||||
type: ToolParameterType = Field(..., description="The type of the parameter")
|
||||
scope: AppSelectorScope | ModelConfigScope | None = None
|
||||
form: ToolParameterForm = Field(...,
|
||||
description="The form of the parameter, schema/form/llm")
|
||||
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
||||
llm_description: Optional[str] = None
|
||||
required: Optional[bool] = False
|
||||
default: Optional[Union[float, int, str]] = None
|
||||
|
@ -346,8 +337,7 @@ class ToolParameter(BaseModel):
|
|||
class ToolProviderIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
description: I18nObject = Field(...,
|
||||
description="The description of the tool")
|
||||
description: I18nObject = Field(..., description="The description of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
tags: Optional[list[ToolLabelEnum]] = Field(
|
||||
|
@ -365,8 +355,7 @@ class ToolIdentity(BaseModel):
|
|||
|
||||
|
||||
class ToolDescription(BaseModel):
|
||||
human: I18nObject = Field(...,
|
||||
description="The description presented to the user")
|
||||
human: I18nObject = Field(..., description="The description presented to the user")
|
||||
llm: str = Field(..., description="The description presented to the LLM")
|
||||
|
||||
|
||||
|
@ -375,8 +364,7 @@ class ToolEntity(BaseModel):
|
|||
parameters: list[ToolParameter] = Field(default_factory=list)
|
||||
description: Optional[ToolDescription] = None
|
||||
output_schema: Optional[dict] = None
|
||||
has_runtime_parameters: bool = Field(
|
||||
default=False, description="Whether the tool has runtime parameters")
|
||||
has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
@ -403,10 +391,8 @@ class WorkflowToolParameterConfiguration(BaseModel):
|
|||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
description: str = Field(...,
|
||||
description="The description of the parameter")
|
||||
form: ToolParameter.ToolParameterForm = Field(
|
||||
..., description="The form of the parameter")
|
||||
description: str = Field(..., description="The description of the parameter")
|
||||
form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
|
||||
|
||||
|
||||
class ToolInvokeMeta(BaseModel):
|
||||
|
@ -414,8 +400,7 @@ class ToolInvokeMeta(BaseModel):
|
|||
Tool invoke meta
|
||||
"""
|
||||
|
||||
time_cost: float = Field(...,
|
||||
description="The time cost of the tool invoke")
|
||||
time_cost: float = Field(..., description="The time cost of the tool invoke")
|
||||
error: Optional[str] = None
|
||||
tool_config: Optional[dict] = None
|
||||
|
||||
|
@ -474,5 +459,4 @@ class ToolProviderID:
|
|||
if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
|
||||
raise ValueError("Invalid plugin id")
|
||||
|
||||
self.organization, self.plugin_name, self.provider_name = value.split(
|
||||
"/")
|
||||
self.organization, self.plugin_name, self.provider_name = value.split("/")
|
||||
|
|
|
@ -1,42 +0,0 @@
|
|||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class AliYuqueTool:
|
||||
# yuque service url
|
||||
server_url = "https://www.yuque.com"
|
||||
|
||||
@staticmethod
|
||||
def auth(token):
|
||||
session = requests.Session()
|
||||
session.headers.update({"Accept": "application/json", "X-Auth-Token": token})
|
||||
login = session.request("GET", AliYuqueTool.server_url + "/api/v2/user")
|
||||
login.raise_for_status()
|
||||
resp = login.json()
|
||||
return resp
|
||||
|
||||
def request(self, method: str, token, tool_parameters: dict[str, Any], path: str) -> str:
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
session = requests.Session()
|
||||
session.headers.update({"accept": "application/json", "X-Auth-Token": token})
|
||||
new_params = {**tool_parameters}
|
||||
|
||||
replacements = {k: v for k, v in new_params.items() if f"{{{k}}}" in path}
|
||||
|
||||
for key, value in replacements.items():
|
||||
path = path.replace(f"{{{key}}}", str(value))
|
||||
del new_params[key]
|
||||
|
||||
if method.upper() in {"POST", "PUT"}:
|
||||
session.headers.update(
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
)
|
||||
response = session.request(method.upper(), self.server_url + path, json=new_params)
|
||||
else:
|
||||
response = session.request(method, self.server_url + path, params=new_params)
|
||||
response.raise_for_status()
|
||||
return response.text
|
|
@ -1,15 +0,0 @@
|
|||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AliYuqueCreateDocumentTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
return self.create_text_message(self.request("POST", token, tool_parameters, "/api/v2/repos/{book_id}/docs"))
|
|
@ -1,99 +0,0 @@
|
|||
identity:
|
||||
name: aliyuque_create_document
|
||||
author: 佐井
|
||||
label:
|
||||
en_US: Create Document
|
||||
zh_Hans: 创建文档
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Creates a new document within a knowledge base without automatic addition to the table of contents. Requires a subsequent call to the "knowledge base directory update API". Supports setting visibility, format, and content. # 接口英文描述
|
||||
zh_Hans: 在知识库中创建新文档,但不会自动加入目录,需额外调用“知识库目录更新接口”。允许设置公开性、格式及正文内容。
|
||||
llm: Creates docs in a KB.
|
||||
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Knowledge Base ID
|
||||
zh_Hans: 知识库ID
|
||||
human_description:
|
||||
en_US: The unique identifier of the knowledge base where the document will be created.
|
||||
zh_Hans: 文档将被创建的知识库的唯一标识。
|
||||
llm_description: ID of the target knowledge base.
|
||||
|
||||
- name: title
|
||||
type: string
|
||||
required: false
|
||||
form: llm
|
||||
label:
|
||||
en_US: Title
|
||||
zh_Hans: 标题
|
||||
human_description:
|
||||
en_US: The title of the document, defaults to 'Untitled' if not provided.
|
||||
zh_Hans: 文档标题,默认为'无标题'如未提供。
|
||||
llm_description: Title of the document, defaults to 'Untitled'.
|
||||
|
||||
- name: public
|
||||
type: select
|
||||
required: false
|
||||
form: llm
|
||||
options:
|
||||
- value: 0
|
||||
label:
|
||||
en_US: Private
|
||||
zh_Hans: 私密
|
||||
- value: 1
|
||||
label:
|
||||
en_US: Public
|
||||
zh_Hans: 公开
|
||||
- value: 2
|
||||
label:
|
||||
en_US: Enterprise-only
|
||||
zh_Hans: 企业内公开
|
||||
label:
|
||||
en_US: Visibility
|
||||
zh_Hans: 公开性
|
||||
human_description:
|
||||
en_US: Document visibility (0 Private, 1 Public, 2 Enterprise-only).
|
||||
zh_Hans: 文档可见性(0 私密, 1 公开, 2 企业内公开)。
|
||||
llm_description: Doc visibility options, 0-private, 1-public, 2-enterprise.
|
||||
|
||||
- name: format
|
||||
type: select
|
||||
required: false
|
||||
form: llm
|
||||
options:
|
||||
- value: markdown
|
||||
label:
|
||||
en_US: markdown
|
||||
zh_Hans: markdown
|
||||
- value: html
|
||||
label:
|
||||
en_US: html
|
||||
zh_Hans: html
|
||||
- value: lake
|
||||
label:
|
||||
en_US: lake
|
||||
zh_Hans: lake
|
||||
label:
|
||||
en_US: Content Format
|
||||
zh_Hans: 内容格式
|
||||
human_description:
|
||||
en_US: Format of the document content (markdown, HTML, Lake).
|
||||
zh_Hans: 文档内容格式(markdown, HTML, Lake)。
|
||||
llm_description: Content format choices, markdown, HTML, Lake.
|
||||
|
||||
- name: body
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Body Content
|
||||
zh_Hans: 正文内容
|
||||
human_description:
|
||||
en_US: The actual content of the document.
|
||||
zh_Hans: 文档的实际内容。
|
||||
llm_description: Content of the document.
|
|
@ -1,17 +0,0 @@
|
|||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AliYuqueDeleteDocumentTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
return self.create_text_message(
|
||||
self.request("DELETE", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}")
|
||||
)
|
|
@ -1,37 +0,0 @@
|
|||
identity:
|
||||
name: aliyuque_delete_document
|
||||
author: 佐井
|
||||
label:
|
||||
en_US: Delete Document
|
||||
zh_Hans: 删除文档
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Delete Document
|
||||
zh_Hans: 根据id删除文档
|
||||
llm: Delete document.
|
||||
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Knowledge Base ID
|
||||
zh_Hans: 知识库ID
|
||||
human_description:
|
||||
en_US: The unique identifier of the knowledge base where the document will be created.
|
||||
zh_Hans: 文档将被创建的知识库的唯一标识。
|
||||
llm_description: ID of the target knowledge base.
|
||||
|
||||
- name: id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Document ID or Path
|
||||
zh_Hans: 文档 ID or 路径
|
||||
human_description:
|
||||
en_US: Document ID or path.
|
||||
zh_Hans: 文档 ID or 路径。
|
||||
llm_description: Document ID or path.
|
|
@ -1,17 +0,0 @@
|
|||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AliYuqueDescribeBookIndexPageTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
return self.create_text_message(
|
||||
self.request("GET", token, tool_parameters, "/api/v2/repos/{group_login}/{book_slug}/index_page")
|
||||
)
|
|
@ -1,15 +0,0 @@
|
|||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> (Union)[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
return self.create_text_message(self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/toc"))
|
|
@ -1,25 +0,0 @@
|
|||
identity:
|
||||
name: aliyuque_describe_book_table_of_contents
|
||||
author: 佐井
|
||||
label:
|
||||
en_US: Get Book's Table of Contents
|
||||
zh_Hans: 获取知识库的目录
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Get Book's Table of Contents.
|
||||
zh_Hans: 获取知识库的目录。
|
||||
llm: Get Book's Table of Contents.
|
||||
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Book ID
|
||||
zh_Hans: 知识库 ID
|
||||
human_description:
|
||||
en_US: Book ID.
|
||||
zh_Hans: 知识库 ID。
|
||||
llm_description: Book ID.
|
|
@ -1,52 +0,0 @@
|
|||
import json
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AliYuqueDescribeDocumentContentTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
new_params = {**tool_parameters}
|
||||
token = new_params.pop("token")
|
||||
if not token or token.lower() == "none":
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
new_params = {**tool_parameters}
|
||||
url = new_params.pop("url")
|
||||
if not url or not url.startswith("http"):
|
||||
raise Exception("url is not valid")
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
if len(path_parts) < 3:
|
||||
raise Exception("url is not correct")
|
||||
doc_id = path_parts[-1]
|
||||
book_slug = path_parts[-2]
|
||||
group_id = path_parts[-3]
|
||||
|
||||
new_params["group_login"] = group_id
|
||||
new_params["book_slug"] = book_slug
|
||||
index_page = json.loads(
|
||||
self.request("GET", token, new_params, "/api/v2/repos/{group_login}/{book_slug}/index_page")
|
||||
)
|
||||
book_id = index_page.get("data", {}).get("book", {}).get("id")
|
||||
if not book_id:
|
||||
raise Exception(f"can not parse book_id from {index_page}")
|
||||
new_params["book_id"] = book_id
|
||||
new_params["id"] = doc_id
|
||||
data = self.request("GET", token, new_params, "/api/v2/repos/{book_id}/docs/{id}")
|
||||
data = json.loads(data)
|
||||
body_only = tool_parameters.get("body_only") or ""
|
||||
if body_only.lower() == "true":
|
||||
return self.create_text_message(data.get("data").get("body"))
|
||||
else:
|
||||
raw = data.get("data")
|
||||
del raw["body_lake"]
|
||||
del raw["body_html"]
|
||||
return self.create_text_message(json.dumps(data))
|
|
@ -1,17 +0,0 @@
|
|||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AliYuqueDescribeDocumentsTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
return self.create_text_message(
|
||||
self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}")
|
||||
)
|
|
@ -1,38 +0,0 @@
|
|||
identity:
|
||||
name: aliyuque_describe_documents
|
||||
author: 佐井
|
||||
label:
|
||||
en_US: Get Doc Detail
|
||||
zh_Hans: 获取文档详情
|
||||
icon: icon.svg
|
||||
|
||||
description:
|
||||
human:
|
||||
en_US: Retrieves detailed information of a specific document identified by its ID or path within a knowledge base.
|
||||
zh_Hans: 根据知识库ID和文档ID或路径获取文档详细信息。
|
||||
llm: Fetches detailed doc info using ID/path from a knowledge base; supports doc lookup in Yuque.
|
||||
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Knowledge Base ID
|
||||
zh_Hans: 知识库 ID
|
||||
human_description:
|
||||
en_US: Identifier for the knowledge base where the document resides.
|
||||
zh_Hans: 文档所属知识库的唯一标识。
|
||||
llm_description: ID of the knowledge base holding the document.
|
||||
|
||||
- name: id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Document ID or Path
|
||||
zh_Hans: 文档 ID 或路径
|
||||
human_description:
|
||||
en_US: The unique identifier or path of the document to retrieve.
|
||||
zh_Hans: 需要获取的文档的ID或其在知识库中的路径。
|
||||
llm_description: Unique doc ID or its path for retrieval.
|
|
@ -1,21 +0,0 @@
|
|||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> (Union)[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
|
||||
doc_ids = tool_parameters.get("doc_ids")
|
||||
if doc_ids:
|
||||
doc_ids = [int(doc_id.strip()) for doc_id in doc_ids.split(",")]
|
||||
tool_parameters["doc_ids"] = doc_ids
|
||||
|
||||
return self.create_text_message(self.request("PUT", token, tool_parameters, "/api/v2/repos/{book_id}/toc"))
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user