diff --git a/.github/linters/.isort.cfg b/.github/linters/.isort.cfg new file mode 100644 index 0000000000..8913d9628a --- /dev/null +++ b/.github/linters/.isort.cfg @@ -0,0 +1,2 @@ +[settings] +line_length=120 diff --git a/api/app.py b/api/app.py index 4e5fd1f24b..b7234b6a17 100644 --- a/api/app.py +++ b/api/app.py @@ -13,30 +13,29 @@ if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true': import langchain langchain.verbose = True -import time -import logging import json +import logging import threading +import time +import warnings -from flask import Flask, request, Response -from flask_cors import CORS - -from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \ - ext_database, ext_storage, ext_mail, ext_code_based_extension, ext_hosting_provider +from commands import register_commands +from config import CloudEditionConfig, Config +from events import event_handlers +from extensions import (ext_celery, ext_code_based_extension, ext_database, ext_hosting_provider, ext_login, ext_mail, + ext_migrate, ext_redis, ext_sentry, ext_storage) from extensions.ext_database import db from extensions.ext_login import login_manager - +from flask import Flask, Response, request +from flask_cors import CORS +from libs.passport import PassportService # DO NOT REMOVE BELOW -from models import model, account, dataset, web, task, source, tool -from events import event_handlers +from models import account, dataset, model, source, task, tool, web +from services.account_service import AccountService + # DO NOT REMOVE ABOVE -from config import Config, CloudEditionConfig -from commands import register_commands -from services.account_service import AccountService -from libs.passport import PassportService -import warnings warnings.simplefilter("ignore", ResourceWarning) # fix windows platform @@ -136,10 +135,10 @@ def unauthorized_handler(): # register blueprint routers def register_blueprints(app): - from controllers.service_api import bp as service_api_bp - from controllers.web import bp as web_bp from controllers.console import bp as console_app_bp from controllers.files import bp as files_bp + from controllers.service_api import bp as service_api_bp + from controllers.web import bp as web_bp CORS(service_api_bp, allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], diff --git a/api/commands.py b/api/commands.py index 9e4681b42f..a9a04c9f13 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1,7 +1,9 @@ +import base64 import datetime import json import math import random +import secrets import string import threading import time @@ -9,26 +11,22 @@ import uuid import click import qdrant_client -from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType -from tqdm import tqdm -from flask import current_app, Flask -from werkzeug.exceptions import NotFound - from core.embedding.cached_embedding import CacheEmbedding from core.index.index import IndexBuilder from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from libs.password import password_pattern, valid_password, hash_password -from libs.helper import email as email_validate from extensions.ext_database import db +from flask import Flask, current_app +from libs.helper import email as email_validate +from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair from models.account import InvitationCode, Tenant, TenantAccountJoin -from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding -from models.model import Account, AppModelConfig, App, MessageAnnotation, Message -import secrets -import base64 - -from models.provider import Provider, ProviderType, ProviderQuotaType, ProviderModel +from models.dataset import Dataset, DatasetCollectionBinding, DatasetQuery, Document +from models.model import Account, App, AppModelConfig, Message, MessageAnnotation +from models.provider import Provider, ProviderModel, ProviderQuotaType, ProviderType +from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType +from tqdm import tqdm +from werkzeug.exceptions import NotFound @click.command('reset-password', help='Reset the account password.') @@ -362,7 +360,7 @@ def create_qdrant_indexes(): model_provider=model_provider) embeddings = CacheEmbedding(embedding_model) - from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig + from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex index = QdrantVectorIndex( dataset=dataset, @@ -433,7 +431,7 @@ def update_qdrant_indexes(): model_provider=model_provider) embeddings = CacheEmbedding(embedding_model) - from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig + from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex index = QdrantVectorIndex( dataset=dataset, @@ -558,7 +556,7 @@ def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: db.session.add(dataset_collection_binding) db.session.commit() - from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig + from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex index = QdrantVectorIndex( dataset=dataset, diff --git a/api/config.py b/api/config.py index 049ee3c46c..fc5e95ed57 100644 --- a/api/config.py +++ b/api/config.py @@ -3,7 +3,6 @@ import os import dotenv - dotenv.load_dotenv() DEFAULTS = { diff --git a/api/constants/model_template.py b/api/constants/model_template.py index c35a0b38d6..055745dbf9 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -1,6 +1,6 @@ import json -from models.model import AppModelConfig, App +from models.model import App, AppModelConfig model_templates = { # completion default mode diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 95a9fe3a08..1394452c80 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -1,29 +1,22 @@ from flask import Blueprint - from libs.external_api import ExternalApi bp = Blueprint('console', __name__, url_prefix='/console/api') api = ExternalApi(bp) # Import other controllers -from . import extension, setup, version, apikey, admin, feature - +from . import admin, apikey, extension, feature, setup, version # Import app controllers -from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio, annotation - +from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message, + model_config, site, statistic) # Import auth controllers -from .auth import login, oauth, data_source_oauth, activate - -# Import datasets controllers -from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source - -# Import workspace controllers -from .workspace import workspace, members, model_providers, account, tool_providers, models - -# Import explore controllers -from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio - -# Import universal chat controllers -from .universal_chat import chat, conversation, message, parameter, audio - +from .auth import activate, data_source_oauth, login, oauth from .billing import billing +# Import datasets controllers +from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing +# Import explore controllers +from .explore import audio, completion, conversation, installed_app, message, parameter, recommended_app, saved_message +# Import universal chat controllers +from .universal_chat import audio, chat, conversation, message, parameter +# Import workspace controllers +from .workspace import account, members, model_providers, models, tool_providers, workspace diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index c14b4e60ab..963d2a0e1e 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,15 +1,14 @@ import os from functools import wraps -from flask import request -from flask_restful import Resource, reqparse -from werkzeug.exceptions import NotFound, Unauthorized - from controllers.console import api from controllers.console.wraps import only_edition_cloud from extensions.ext_database import db +from flask import request +from flask_restful import Resource, reqparse from libs.helper import supported_language -from models.model import RecommendedApp, App, InstalledApp +from models.model import App, InstalledApp, RecommendedApp +from werkzeug.exceptions import NotFound, Unauthorized def admin_required(view): diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 22c9f85f45..23b11afe1c 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,17 +1,16 @@ -from flask_login import current_user -from libs.login import login_required import flask_restful -from flask_restful import Resource, fields, marshal_with -from werkzeug.exceptions import Forbidden - from extensions.ext_database import db -from models.model import App, ApiToken +from flask_login import current_user +from flask_restful import Resource, fields, marshal_with +from libs.helper import TimestampField +from libs.login import login_required from models.dataset import Dataset +from models.model import ApiToken, App +from werkzeug.exceptions import Forbidden from . import api from .setup import setup_required from .wraps import account_initialization_required -from libs.helper import TimestampField api_key_fields = { 'id': fields.String, diff --git a/api/controllers/console/app/__init__.py b/api/controllers/console/app/__init__.py index f0c7956e0f..b0b07517f1 100644 --- a/api/controllers/console/app/__init__.py +++ b/api/controllers/console/app/__init__.py @@ -1,9 +1,8 @@ -from flask_login import current_user -from werkzeug.exceptions import NotFound - from controllers.console.app.error import AppUnavailableError from extensions.ext_database import db +from flask_login import current_user from models.model import App +from werkzeug.exceptions import NotFound def _get_app(app_id, mode=None): diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index c92f0570dc..c7693fb950 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,11 +1,11 @@ -from flask_restful import Resource, reqparse - from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from flask_restful import Resource, reqparse from libs.login import login_required from services.advanced_prompt_template_service import AdvancedPromptTemplateService + class AdvancedPromptTemplateList(Resource): @setup_required diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index af7f52e970..439ed07345 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,18 +1,17 @@ -from flask_login import current_user -from flask_restful import Resource, reqparse, marshal_with, marshal -from werkzeug.exceptions import Forbidden - from controllers.console import api from controllers.console.app.error import NoFileUploadedError from controllers.console.datasets.error import TooManyFilesError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_redis import redis_client -from fields.annotation_fields import annotation_list_fields, annotation_hit_history_list_fields, annotation_fields, \ - annotation_hit_history_fields +from fields.annotation_fields import (annotation_fields, annotation_hit_history_fields, + annotation_hit_history_list_fields, annotation_list_fields) +from flask import request +from flask_login import current_user +from flask_restful import Resource, marshal, marshal_with, reqparse from libs.login import login_required from services.annotation_service import AppAnnotationService -from flask import request +from werkzeug.exceptions import Forbidden class AnnotationReplyActionApi(Resource): diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 7f65023734..6ae0ef4806 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -3,27 +3,25 @@ import json import logging from datetime import datetime -from flask_login import current_user - -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.provider_manager import ProviderManager -from libs.login import login_required -from flask_restful import Resource, reqparse, marshal_with, abort, inputs -from werkzeug.exceptions import Forbidden - -from constants.model_template import model_templates, demo_model_templates +from constants.model_template import demo_model_templates, model_templates from controllers.console import api from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError +from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.provider_manager import ProviderManager from events.app_event import app_was_created, app_was_deleted -from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \ - app_detail_fields_with_site from extensions.ext_database import db +from fields.app_fields import (app_detail_fields, app_detail_fields_with_site, app_pagination_fields, + template_list_fields) +from flask_login import current_user +from flask_restful import Resource, abort, inputs, marshal_with, reqparse +from libs.login import login_required from models.model import App, AppModelConfig, Site from services.app_model_config_service import AppModelConfigService +from werkzeug.exceptions import Forbidden def _get_app(app_id, tenant_id): diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index c78005adf2..ed8b36c00c 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,26 +1,24 @@ # -*- coding:utf-8 -*- import logging -from flask import request - -from core.model_runtime.errors.invoke import InvokeError -from libs.login import login_required -from werkzeug.exceptions import InternalServerError - import services from controllers.console import api from controllers.console.app import _get_app -from controllers.console.app.error import AppUnavailableError, \ - ProviderNotInitializeError, CompletionRequestError, ProviderQuotaExceededError, \ - ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \ - UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError +from controllers.console.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError, + NoAudioUploadedError, ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, UnsupportedAudioTypeError) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.model_runtime.errors.invoke import InvokeError +from flask import request from flask_restful import Resource +from libs.login import login_required from services.audio_service import AudioService -from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ - UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError +from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError) +from werkzeug.exceptions import InternalServerError class ChatMessageAudioApi(Resource): diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 70bd92a25a..50c9825a92 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -4,27 +4,24 @@ import logging from typing import Generator, Union import flask_login -from flask import Response, stream_with_context - -from core.application_queue_manager import ApplicationQueueManager -from core.entities.application_entities import InvokeFrom -from core.model_runtime.errors.invoke import InvokeError -from libs.login import login_required -from werkzeug.exceptions import InternalServerError, NotFound - import services from controllers.console import api from controllers.console.app import _get_app -from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, \ - ProviderNotInitializeError, CompletionRequestError, ProviderQuotaExceededError, \ - ProviderModelCurrentlyNotSupportError +from controllers.console.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError, + ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, + ProviderQuotaExceededError) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError -from libs.helper import uuid_value +from core.application_queue_manager import ApplicationQueueManager +from core.entities.application_entities import InvokeFrom +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.model_runtime.errors.invoke import InvokeError +from flask import Response, stream_with_context from flask_restful import Resource, reqparse - +from libs.helper import uuid_value +from libs.login import login_required from services.completion_service import CompletionService +from werkzeug.exceptions import InternalServerError, NotFound # define completion message api for user diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 66abf63e65..f159f74c71 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,23 +1,22 @@ from datetime import datetime import pytz -from flask_login import current_user -from libs.login import login_required -from flask_restful import Resource, reqparse, marshal_with -from flask_restful.inputs import int_range -from sqlalchemy import or_, func -from sqlalchemy.orm import joinedload -from werkzeug.exceptions import NotFound - from controllers.console import api from controllers.console.app import _get_app from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from fields.conversation_fields import conversation_pagination_fields, conversation_detail_fields, \ - conversation_message_detail_fields, conversation_with_summary_pagination_fields -from libs.helper import datetime_string from extensions.ext_database import db -from models.model import Message, MessageAnnotation, Conversation +from fields.conversation_fields import (conversation_detail_fields, conversation_message_detail_fields, + conversation_pagination_fields, conversation_with_summary_pagination_fields) +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse +from flask_restful.inputs import int_range +from libs.helper import datetime_string +from libs.login import login_required +from models.model import Conversation, Message, MessageAnnotation +from sqlalchemy import func, or_ +from sqlalchemy.orm import joinedload +from werkzeug.exceptions import NotFound class CompletionConversationApi(Resource): diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 48d1d5c0bb..d7a320db99 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,16 +1,14 @@ -from flask_login import current_user - -from core.model_runtime.errors.invoke import InvokeError -from libs.login import login_required -from flask_restful import Resource, reqparse - from controllers.console import api -from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \ - CompletionRequestError, ProviderModelCurrentlyNotSupportError +from controllers.console.app.error import (CompletionRequestError, ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, ProviderQuotaExceededError) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.generator.llm_generator import LLMGenerator -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.model_runtime.errors.invoke import InvokeError +from flask_login import current_user +from flask_restful import Resource, reqparse +from libs.login import login_required class RuleGenerateApi(Resource): diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 499cd0a2bd..db1061f40e 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,34 +1,34 @@ import json import logging -from typing import Union, Generator - -from flask import Response, stream_with_context -from flask_login import current_user -from flask_restful import Resource, reqparse, marshal_with, fields -from flask_restful.inputs import int_range -from werkzeug.exceptions import InternalServerError, NotFound, Forbidden +from typing import Generator, Union from controllers.console import api from controllers.console.app import _get_app -from controllers.console.app.error import CompletionRequestError, ProviderNotInitializeError, \ - AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError +from controllers.console.app.error import (AppMoreLikeThisDisabledError, CompletionRequestError, + ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, + ProviderQuotaExceededError) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from core.entities.application_entities import InvokeFrom -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from libs.login import login_required -from fields.conversation_fields import message_detail_fields, annotation_fields +from extensions.ext_database import db +from fields.conversation_fields import annotation_fields, message_detail_fields +from flask import Response, stream_with_context +from flask_login import current_user +from flask_restful import Resource, fields, marshal_with, reqparse +from flask_restful.inputs import int_range from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination -from extensions.ext_database import db -from models.model import MessageAnnotation, Conversation, Message, MessageFeedback +from libs.login import login_required +from models.model import Conversation, Message, MessageAnnotation, MessageFeedback from services.annotation_service import AppAnnotationService from services.completion_service import CompletionService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError from services.message_service import MessageService +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound class ChatMessageListApi(Resource): diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 7ed92a6531..d447bfa756 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -1,16 +1,15 @@ # -*- coding:utf-8 -*- -from flask import request -from flask_restful import Resource -from flask_login import current_user - from controllers.console import api from controllers.console.app import _get_app from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from libs.login import login_required from events.app_event import app_model_config_was_updated from extensions.ext_database import db +from flask import request +from flask_login import current_user +from flask_restful import Resource +from libs.login import login_required from models.model import AppModelConfig from services.app_model_config_service import AppModelConfigService diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 1812d9e190..93b6a4eecf 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,17 +1,16 @@ # -*- coding:utf-8 -*- -from flask_login import current_user -from libs.login import login_required -from flask_restful import Resource, reqparse, marshal_with -from werkzeug.exceptions import NotFound, Forbidden - from controllers.console import api from controllers.console.app import _get_app from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from fields.app_fields import app_site_fields -from libs.helper import supported_language from extensions.ext_database import db +from fields.app_fields import app_site_fields +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse +from libs.helper import supported_language +from libs.login import login_required from models.model import Site +from werkzeug.exceptions import Forbidden, NotFound def parse_app_site_args(): diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 1a9c50db73..f2c1726433 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -1,19 +1,18 @@ # -*- coding:utf-8 -*- -from decimal import Decimal from datetime import datetime +from decimal import Decimal import pytz -from flask import jsonify -from flask_login import current_user -from libs.login import login_required -from flask_restful import Resource, reqparse - from controllers.console import api from controllers.console.app import _get_app from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from libs.helper import datetime_string from extensions.ext_database import db +from flask import jsonify +from flask_login import current_user +from flask_restful import Resource, reqparse +from libs.helper import datetime_string +from libs.login import login_required class DailyConversationStatistic(Resource): diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index c0209a05cd..7dc49af6bc 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -2,13 +2,12 @@ import base64 import secrets from datetime import datetime -from flask_restful import Resource, reqparse - from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db +from flask_restful import Resource, reqparse from libs.helper import email, str_len, supported_language, timezone -from libs.password import valid_password, hash_password +from libs.password import hash_password, valid_password from models.account import AccountStatus, Tenant from services.account_service import RegisterService diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index dec5a464c1..ff9084db30 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -1,15 +1,14 @@ import logging import requests -from flask import request, redirect, current_app +from controllers.console import api +from flask import current_app, redirect, request from flask_login import current_user - from flask_restful import Resource -from werkzeug.exceptions import Forbidden - from libs.login import login_required from libs.oauth_data_source import NotionOAuth -from controllers.console import api +from werkzeug.exceptions import Forbidden + from ..setup import setup_required from ..wraps import account_initialization_required diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 08d0a94b5b..a7c06f481d 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,12 +1,11 @@ # -*- coding:utf-8 -*- import flask import flask_login -from flask import request, current_app -from flask_restful import Resource, reqparse - import services from controllers.console import api from controllers.console.setup import setup_required +from flask import current_app, request +from flask_restful import Resource, reqparse from libs.helper import email from libs.password import valid_password from services.account_service import AccountService, TenantService diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index b1f8967119..9da78092d4 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -3,13 +3,13 @@ from datetime import datetime from typing import Optional import requests -from flask import request, redirect, current_app -from flask_restful import Resource - -from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth from extensions.ext_database import db +from flask import current_app, redirect, request +from flask_restful import Resource +from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models.account import Account, AccountStatus from services.account_service import AccountService, RegisterService + from .. import api diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 83741ce2ed..2053b08f07 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,10 +1,8 @@ -from flask_restful import Resource, reqparse -from flask_login import current_user - from controllers.console import api from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required -from controllers.console.wraps import only_edition_cloud +from controllers.console.wraps import account_initialization_required, only_edition_cloud +from flask_login import current_user +from flask_restful import Resource, reqparse from libs.login import login_required from services.billing_service import BillingService diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 9422526556..a9ecd3d27d 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,23 +1,22 @@ import datetime import json -from flask import request -from flask_login import current_user -from libs.login import login_required -from flask_restful import Resource, marshal_with, reqparse -from werkzeug.exceptions import NotFound - from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.data_loader.loader.notion import NotionLoader from core.indexing_runner import IndexingRunner from extensions.ext_database import db -from fields.data_source_fields import integrate_notion_info_list_fields, integrate_list_fields +from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields +from flask import request +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse +from libs.login import login_required from models.dataset import Document from models.source import DataSourceBinding from services.dataset_service import DatasetService, DocumentService from tasks.document_indexing_sync_task import document_indexing_sync_task +from werkzeug.exceptions import NotFound class DataSourceApi(Resource): diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index ac79374660..7be8e87ce0 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,29 +1,28 @@ # -*- coding:utf-8 -*- import flask_restful -from flask import request, current_app -from flask_login import current_user - -from controllers.console.apikey import api_key_list, api_key_fields -from core.model_runtime.entities.model_entities import ModelType -from core.provider_manager import ProviderManager -from libs.login import login_required -from flask_restful import Resource, reqparse, marshal, marshal_with -from werkzeug.exceptions import NotFound, Forbidden import services from controllers.console import api +from controllers.console.apikey import api_key_fields, api_key_list from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.indexing_runner import IndexingRunner from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.indexing_runner import IndexingRunner +from core.model_runtime.entities.model_entities import ModelType +from core.provider_manager import ProviderManager +from extensions.ext_database import db from fields.app_fields import related_app_list from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.document_fields import document_status_fields -from extensions.ext_database import db -from models.dataset import DocumentSegment, Document -from models.model import UploadFile, ApiToken +from flask import current_app, request +from flask_login import current_user +from flask_restful import Resource, marshal, marshal_with, reqparse +from libs.login import login_required +from models.dataset import Document, DocumentSegment +from models.model import ApiToken, UploadFile from services.dataset_service import DatasetService, DocumentService +from werkzeug.exceptions import Forbidden, NotFound def _validate_name(name): diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 90af2bce05..b3830dcb75 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -2,38 +2,35 @@ from datetime import datetime from typing import List -from flask import request -from flask_login import current_user - +import services +from controllers.console import api +from controllers.console.app.error import (ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, + ProviderQuotaExceededError) +from controllers.console.datasets.error import (ArchivedDocumentImmutableError, DocumentAlreadyFinishedError, + DocumentIndexingError, InvalidActionError, InvalidMetadataError) +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from core.errors.error import (LLMBadRequestError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError, + QuotaExceededError) +from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError -from libs.login import login_required -from flask_restful import Resource, fields, marshal, marshal_with, reqparse -from sqlalchemy import desc, asc -from werkzeug.exceptions import NotFound, Forbidden - -import services -from controllers.console import api -from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \ - ProviderModelCurrentlyNotSupportError -from controllers.console.datasets.error import DocumentAlreadyFinishedError, InvalidActionError, DocumentIndexingError, \ - InvalidMetadataError, ArchivedDocumentImmutableError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.indexing_runner import IndexingRunner -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ - LLMBadRequestError -from extensions.ext_redis import redis_client -from fields.document_fields import document_with_segments_fields, document_fields, \ - dataset_and_document_fields, document_status_fields from extensions.ext_database import db -from models.dataset import DatasetProcessRule, Dataset -from models.dataset import Document, DocumentSegment +from extensions.ext_redis import redis_client +from fields.document_fields import (dataset_and_document_fields, document_fields, document_status_fields, + document_with_segments_fields) +from flask import request +from flask_login import current_user +from flask_restful import Resource, fields, marshal, marshal_with, reqparse +from libs.login import login_required +from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment from models.model import UploadFile -from services.dataset_service import DocumentService, DatasetService +from services.dataset_service import DatasetService, DocumentService +from sqlalchemy import asc, desc from tasks.add_document_to_index_task import add_document_to_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task +from werkzeug.exceptions import Forbidden, NotFound class DocumentResource(Resource): diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index f3e39813fe..befe7f30d4 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,11 +1,8 @@ # -*- coding:utf-8 -*- import uuid from datetime import datetime -from flask import request -from flask_login import current_user -from flask_restful import Resource, reqparse, marshal -from werkzeug.exceptions import NotFound, Forbidden +import pandas as pd import services from controllers.console import api from controllers.console.app.error import ProviderNotInitializeError @@ -15,17 +12,19 @@ from controllers.console.wraps import account_initialization_required, cloud_edi from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from libs.login import login_required from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import segment_fields +from flask import request +from flask_login import current_user +from flask_restful import Resource, marshal, reqparse +from libs.login import login_required from models.dataset import DocumentSegment - from services.dataset_service import DatasetService, DocumentService, SegmentService -from tasks.enable_segment_to_index_task import enable_segment_to_index_task -from tasks.disable_segment_from_index_task import disable_segment_from_index_task from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task -import pandas as pd +from tasks.disable_segment_from_index_task import disable_segment_from_index_task +from tasks.enable_segment_to_index_task import enable_segment_to_index_task +from werkzeug.exceptions import Forbidden, NotFound class DatasetDocumentSegmentListApi(Resource): diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index 19404af4a9..212fd78b37 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -1,18 +1,14 @@ -from flask import request, current_app -from flask_login import current_user - import services -from libs.login import login_required -from flask_restful import Resource, marshal_with - from controllers.console import api -from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \ - UnsupportedFileTypeError - +from controllers.console.datasets.error import (FileTooLargeError, NoFileUploadedError, TooManyFilesError, + UnsupportedFileTypeError) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from fields.file_fields import upload_config_fields, file_fields - +from fields.file_fields import file_fields, upload_config_fields +from flask import current_app, request +from flask_login import current_user +from flask_restful import Resource, marshal_with +from libs.login import login_required from services.file_service import FileService PREVIEW_WORDS_LIMIT = 3000 diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 06f75e2706..a32a3217e5 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,24 +1,22 @@ import logging -from flask_login import current_user - -from core.model_runtime.errors.invoke import InvokeError -from libs.login import login_required -from flask_restful import Resource, reqparse, marshal -from werkzeug.exceptions import InternalServerError, NotFound, Forbidden - import services from controllers.console import api -from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \ - ProviderModelCurrentlyNotSupportError, CompletionRequestError -from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError +from controllers.console.app.error import (CompletionRequestError, ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, ProviderQuotaExceededError) +from controllers.console.datasets.error import DatasetNotInitializedError, HighQualityDatasetOnlyError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ - LLMBadRequestError +from core.errors.error import (LLMBadRequestError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError, + QuotaExceededError) +from core.model_runtime.errors.invoke import InvokeError from fields.hit_testing_fields import hit_testing_record_fields +from flask_login import current_user +from flask_restful import Resource, marshal, reqparse +from libs.login import login_required from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound class HitTestingApi(Resource): diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index a01b7b40d3..00ae66e663 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -1,22 +1,21 @@ # -*- coding:utf-8 -*- import logging -from flask import request -from werkzeug.exceptions import InternalServerError - import services from controllers.console import api -from controllers.console.app.error import AppUnavailableError, ProviderNotInitializeError, \ - ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError, \ - NoAudioUploadedError, AudioTooLargeError, \ - UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError +from controllers.console.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError, + NoAudioUploadedError, ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, UnsupportedAudioTypeError) from controllers.console.explore.wraps import InstalledAppResource -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from services.audio_service import AudioService -from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ - UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError +from flask import request from models.model import AppModelConfig +from services.audio_service import AudioService +from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError) +from werkzeug.exceptions import InternalServerError class ChatAudioApi(InstalledAppResource): diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 5b690af052..6641fbc90a 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -4,24 +4,24 @@ import logging from datetime import datetime from typing import Generator, Union -from flask import Response, stream_with_context -from flask_login import current_user -from flask_restful import reqparse -from werkzeug.exceptions import InternalServerError, NotFound - import services from controllers.console import api -from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \ - ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError -from controllers.console.explore.error import NotCompletionAppError, NotChatAppError +from controllers.console.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError, + ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, + ProviderQuotaExceededError) +from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from core.application_queue_manager import ApplicationQueueManager from core.entities.application_entities import InvokeFrom -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db +from flask import Response, stream_with_context +from flask_login import current_user +from flask_restful import reqparse from libs.helper import uuid_value from services.completion_service import CompletionService +from werkzeug.exceptions import InternalServerError, NotFound # define completion api for user diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 9848175fb5..1b6b493671 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,17 +1,16 @@ # -*- coding:utf-8 -*- -from flask_login import current_user -from flask_restful import fields, reqparse, marshal_with -from flask_restful.inputs import int_range -from werkzeug.exceptions import NotFound - from controllers.console import api from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields +from flask_login import current_user +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range from libs.helper import TimestampField, uuid_value from services.conversation_service import ConversationService -from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError +from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.web_conversation_service import WebConversationService +from werkzeug.exceptions import NotFound class ConversationListApi(InstalledAppResource): diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index b1e30f4455..7bde88efbe 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,20 +1,18 @@ # -*- coding:utf-8 -*- from datetime import datetime -from flask_login import current_user -from libs.login import login_required -from flask_restful import Resource, reqparse, marshal_with, inputs -from sqlalchemy import and_ -from werkzeug.exceptions import NotFound, Forbidden, BadRequest - from controllers.console import api from controllers.console.explore.wraps import InstalledAppResource -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields +from flask_login import current_user +from flask_restful import Resource, inputs, marshal_with, reqparse +from libs.login import login_required from models.model import App, InstalledApp, RecommendedApp from services.account_service import TenantService -from controllers.console.wraps import cloud_edition_billing_resource_check +from sqlalchemy import and_ +from werkzeug.exceptions import BadRequest, Forbidden, NotFound class InstalledAppsListApi(Resource): diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index a521beae18..1f4178c17b 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -3,29 +3,29 @@ import json import logging from typing import Generator, Union -from flask import stream_with_context, Response -from flask_login import current_user -from flask_restful import reqparse, marshal_with -from flask_restful.inputs import int_range -from werkzeug.exceptions import NotFound, InternalServerError - import services from controllers.console import api -from controllers.console.app.error import AppMoreLikeThisDisabledError, ProviderNotInitializeError, \ - ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError -from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \ - NotChatAppError +from controllers.console.app.error import (AppMoreLikeThisDisabledError, CompletionRequestError, + ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, + ProviderQuotaExceededError) +from controllers.console.explore.error import (AppSuggestedQuestionsAfterAnswerDisabledError, NotChatAppError, + NotCompletionAppError) from controllers.console.explore.wraps import InstalledAppResource from core.entities.application_entities import InvokeFrom -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields +from flask import Response, stream_with_context +from flask_login import current_user +from flask_restful import marshal_with, reqparse +from flask_restful.inputs import int_range from libs.helper import uuid_value from services.completion_service import CompletionService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService +from werkzeug.exceptions import InternalServerError, NotFound class MessageListApi(InstalledAppResource): diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 435ebaee7c..7be6966129 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,10 +1,8 @@ # -*- coding:utf-8 -*- -from flask_restful import marshal_with, fields -from flask import current_app - from controllers.console import api from controllers.console.explore.wraps import InstalledAppResource - +from flask import current_app +from flask_restful import fields, marshal_with from models.model import InstalledApp diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index f2d8b89803..71d77ee74a 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,15 +1,14 @@ # -*- coding:utf-8 -*- -from flask_login import current_user -from libs.login import login_required -from flask_restful import Resource, fields, marshal_with -from sqlalchemy import and_ - from controllers.console import api from controllers.console.app.error import AppNotFoundError from controllers.console.wraps import account_initialization_required from extensions.ext_database import db +from flask_login import current_user +from flask_restful import Resource, fields, marshal_with +from libs.login import login_required from models.model import App, InstalledApp, RecommendedApp from services.account_service import TenantService +from sqlalchemy import and_ app_fields = { 'id': fields.String, diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 7977d0ebaf..9d355df355 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,15 +1,14 @@ -from flask_login import current_user -from flask_restful import reqparse, marshal_with, fields -from flask_restful.inputs import int_range -from werkzeug.exceptions import NotFound - from controllers.console import api from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource -from libs.helper import uuid_value, TimestampField +from fields.conversation_fields import message_file_fields +from flask_login import current_user +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range +from libs.helper import TimestampField, uuid_value from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -from fields.conversation_fields import message_file_fields +from werkzeug.exceptions import NotFound feedback_fields = { 'rating': fields.String diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 5e65c94d0d..d02b869bf7 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,13 +1,12 @@ -from flask_login import current_user -from libs.login import login_required -from flask_restful import Resource from functools import wraps -from werkzeug.exceptions import NotFound - from controllers.console.wraps import account_initialization_required from extensions.ext_database import db +from flask_login import current_user +from flask_restful import Resource +from libs.login import login_required from models.model import InstalledApp +from werkzeug.exceptions import NotFound def installed_app_required(view=None): diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 50b33e39ad..78374cf2a9 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,14 +1,13 @@ -from flask_restful import Resource, reqparse, marshal_with -from flask_login import current_user - from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from fields.api_based_extension_fields import api_based_extension_fields +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse from libs.login import login_required from models.api_based_extension import APIBasedExtension -from fields.api_based_extension_fields import api_based_extension_fields -from services.code_based_extension_service import CodeBasedExtensionService from services.api_based_extension_service import APIBasedExtensionService +from services.code_based_extension_service import CodeBasedExtensionService class CodeBasedExtensionAPI(Resource): diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 0d1d61ad00..f1e6286b6d 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,8 +1,8 @@ -from flask_restful import Resource from flask_login import current_user +from flask_restful import Resource +from services.feature_service import FeatureService from . import api -from services.feature_service import FeatureService class FeatureApi(Resource): diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index fd9273e342..ad37561e42 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,15 +1,13 @@ # -*- coding:utf-8 -*- from functools import wraps -from flask import request, current_app -from flask_restful import Resource, reqparse - from extensions.ext_database import db -from models.model import DifySetup -from services.account_service import AccountService, TenantService, RegisterService - +from flask import current_app, request +from flask_restful import Resource, reqparse from libs.helper import email, str_len from libs.password import valid_password +from models.model import DifySetup +from services.account_service import AccountService, RegisterService, TenantService from . import api from .error import AlreadySetupError, NotSetupError diff --git a/api/controllers/console/universal_chat/audio.py b/api/controllers/console/universal_chat/audio.py index 2f0b2568ac..2566448d49 100644 --- a/api/controllers/console/universal_chat/audio.py +++ b/api/controllers/console/universal_chat/audio.py @@ -1,22 +1,21 @@ # -*- coding:utf-8 -*- import logging -from flask import request -from werkzeug.exceptions import InternalServerError - import services from controllers.console import api -from controllers.console.app.error import AppUnavailableError, ProviderNotInitializeError, \ - ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError, \ - NoAudioUploadedError, AudioTooLargeError, \ - UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError +from controllers.console.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError, + NoAudioUploadedError, ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, UnsupportedAudioTypeError) from controllers.console.universal_chat.wraps import UniversalChatResource -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from services.audio_service import AudioService -from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ - UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError +from flask import request from models.model import AppModelConfig +from services.audio_service import AudioService +from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError) +from werkzeug.exceptions import InternalServerError class UniversalChatAudioApi(UniversalChatResource): diff --git a/api/controllers/console/universal_chat/chat.py b/api/controllers/console/universal_chat/chat.py index cf900efced..7530af05e7 100644 --- a/api/controllers/console/universal_chat/chat.py +++ b/api/controllers/console/universal_chat/chat.py @@ -2,22 +2,22 @@ import json import logging from typing import Generator, Union -from flask import Response, stream_with_context -from flask_login import current_user -from flask_restful import reqparse -from werkzeug.exceptions import InternalServerError, NotFound - import services from controllers.console import api -from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \ - ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError +from controllers.console.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError, + ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, + ProviderQuotaExceededError) from controllers.console.universal_chat.wraps import UniversalChatResource from core.application_queue_manager import ApplicationQueueManager from core.entities.application_entities import InvokeFrom -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError +from flask import Response, stream_with_context +from flask_login import current_user +from flask_restful import reqparse from libs.helper import uuid_value from services.completion_service import CompletionService +from werkzeug.exceptions import InternalServerError, NotFound class UniversalChatApi(UniversalChatResource): diff --git a/api/controllers/console/universal_chat/conversation.py b/api/controllers/console/universal_chat/conversation.py index 046a6c1341..af141b67a5 100644 --- a/api/controllers/console/universal_chat/conversation.py +++ b/api/controllers/console/universal_chat/conversation.py @@ -1,17 +1,16 @@ # -*- coding:utf-8 -*- -from flask_login import current_user -from flask_restful import fields, reqparse, marshal_with -from flask_restful.inputs import int_range -from werkzeug.exceptions import NotFound - from controllers.console import api from controllers.console.universal_chat.wraps import UniversalChatResource -from fields.conversation_fields import conversation_with_model_config_infinite_scroll_pagination_fields, \ - conversation_with_model_config_fields +from fields.conversation_fields import (conversation_with_model_config_fields, + conversation_with_model_config_infinite_scroll_pagination_fields) +from flask_login import current_user +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range from libs.helper import TimestampField, uuid_value from services.conversation_service import ConversationService -from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError +from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.web_conversation_service import WebConversationService +from werkzeug.exceptions import NotFound class UniversalChatConversationListApi(UniversalChatResource): diff --git a/api/controllers/console/universal_chat/message.py b/api/controllers/console/universal_chat/message.py index 6421d85a1f..503615d751 100644 --- a/api/controllers/console/universal_chat/message.py +++ b/api/controllers/console/universal_chat/message.py @@ -1,23 +1,22 @@ # -*- coding:utf-8 -*- import logging -from flask_login import current_user -from flask_restful import reqparse, fields, marshal_with -from flask_restful.inputs import int_range -from werkzeug.exceptions import NotFound, InternalServerError - import services from controllers.console import api -from controllers.console.app.error import ProviderNotInitializeError, \ - ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError +from controllers.console.app.error import (CompletionRequestError, ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, ProviderQuotaExceededError) from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError from controllers.console.universal_chat.wraps import UniversalChatResource -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from libs.helper import uuid_value, TimestampField +from flask_login import current_user +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range +from libs.helper import TimestampField, uuid_value from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService +from werkzeug.exceptions import InternalServerError, NotFound class UniversalChatMessageListApi(UniversalChatResource): diff --git a/api/controllers/console/universal_chat/parameter.py b/api/controllers/console/universal_chat/parameter.py index 31bb9f2790..dca86e39c1 100644 --- a/api/controllers/console/universal_chat/parameter.py +++ b/api/controllers/console/universal_chat/parameter.py @@ -1,11 +1,9 @@ # -*- coding:utf-8 -*- import json -from flask_restful import marshal_with, fields - from controllers.console import api from controllers.console.universal_chat.wraps import UniversalChatResource - +from flask_restful import fields, marshal_with from models.model import App diff --git a/api/controllers/console/universal_chat/wraps.py b/api/controllers/console/universal_chat/wraps.py index 1fd1747848..3e5600639e 100644 --- a/api/controllers/console/universal_chat/wraps.py +++ b/api/controllers/console/universal_chat/wraps.py @@ -1,12 +1,12 @@ import json from functools import wraps -from flask_login import current_user -from libs.login import login_required -from flask_restful import Resource from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db +from flask_login import current_user +from flask_restful import Resource +from libs.login import login_required from models.model import App, AppModelConfig diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 2561b06b5d..ba49506618 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -5,7 +5,7 @@ import logging import requests from flask import current_app -from flask_restful import reqparse, Resource +from flask_restful import Resource, reqparse from werkzeug.exceptions import InternalServerError from . import api diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 8df13d9d02..43c15e1a95 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -2,21 +2,20 @@ from datetime import datetime import pytz -from flask import current_app, request -from flask_login import current_user -from libs.login import login_required -from flask_restful import Resource, reqparse, fields, marshal_with - -from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError from controllers.console import api from controllers.console.setup import setup_required -from controllers.console.workspace.error import AccountAlreadyInitedError, InvalidInvitationCodeError, \ - RepeatPasswordNotMatchError, CurrentPasswordIncorrectError +from controllers.console.workspace.error import (AccountAlreadyInitedError, CurrentPasswordIncorrectError, + InvalidInvitationCodeError, RepeatPasswordNotMatchError) from controllers.console.wraps import account_initialization_required -from libs.helper import TimestampField, supported_language, timezone from extensions.ext_database import db -from models.account import InvitationCode, AccountIntegrate +from flask import current_app, request +from flask_login import current_user +from flask_restful import Resource, fields, marshal_with, reqparse +from libs.helper import TimestampField, supported_language, timezone +from libs.login import login_required +from models.account import AccountIntegrate, InvitationCode from services.account_service import AccountService +from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError account_fields = { 'id': fields.String, diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 104180e4a6..9c6745db88 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,17 +1,16 @@ # -*- coding:utf-8 -*- -from flask import current_app -from flask_login import current_user -from libs.login import login_required -from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal - import services from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from libs.helper import TimestampField from extensions.ext_database import db +from flask import current_app +from flask_login import current_user +from flask_restful import Resource, abort, fields, marshal, marshal_with, reqparse +from libs.helper import TimestampField +from libs.login import login_required from models.account import Account, TenantAccountJoin -from services.account_service import TenantService, RegisterService +from services.account_service import RegisterService, TenantService account_fields = { 'id': fields.String, diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 4086fdf049..a78a253dd0 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -1,19 +1,18 @@ import io -from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse -from werkzeug.exceptions import Forbidden - from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder +from flask import send_file +from flask_login import current_user +from flask_restful import Resource, reqparse from libs.login import login_required from services.billing_service import BillingService from services.model_provider_service import ModelProviderService +from werkzeug.exceptions import Forbidden class ModelProviderListApi(Resource): diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index b5cb97d547..305c9f09af 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -1,17 +1,16 @@ import logging -from flask_login import current_user -from flask_restful import reqparse, Resource -from werkzeug.exceptions import Forbidden - from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder +from flask_login import current_user +from flask_restful import Resource, reqparse from libs.login import login_required from services.model_provider_service import ModelProviderService +from werkzeug.exceptions import Forbidden class DefaultModelApi(Resource): diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 3cc9c14fb5..c6416e1d3b 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,17 +1,16 @@ import json -from flask_login import current_user -from libs.login import login_required -from flask_restful import Resource, abort, reqparse -from werkzeug.exceptions import Forbidden - from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.tool.provider.errors import ToolValidateFailedError from core.tool.provider.tool_provider_service import ToolProviderService from extensions.ext_database import db +from flask_login import current_user +from flask_restful import Resource, abort, reqparse +from libs.login import login_required from models.tool import ToolProvider, ToolProviderName +from werkzeug.exceptions import Forbidden class ToolProviderListApi(Resource): diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index de07ab62c7..8f00d76f7a 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,24 +1,24 @@ # -*- coding:utf-8 -*- import logging -from flask import request -from flask_login import current_user -from libs.login import login_required -from flask_restful import Resource, fields, marshal_with, reqparse, marshal, inputs - +import services from controllers.console import api from controllers.console.admin import admin_required -from controllers.console.setup import setup_required +from controllers.console.datasets.error import (FileTooLargeError, NoFileUploadedError, TooManyFilesError, + UnsupportedFileTypeError) from controllers.console.error import AccountNotLinkTenantError +from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, UnsupportedFileTypeError -from libs.helper import TimestampField from extensions.ext_database import db +from flask import request +from flask_login import current_user +from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from libs.helper import TimestampField +from libs.login import login_required from models.account import Tenant -import services from services.account_service import TenantService -from services.workspace_service import WorkspaceService from services.file_service import FileService +from services.workspace_service import WorkspaceService provider_fields = { 'provider_name': fields.String, diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index c19ef14708..049657fc85 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -1,10 +1,9 @@ # -*- coding:utf-8 -*- from functools import wraps -from flask import current_app, abort -from flask_login import current_user - from controllers.console.workspace.error import AccountNotInitializedError +from flask import abort, current_app +from flask_login import current_user from services.feature_service import FeatureService diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index 83374497b7..ff8d54c726 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -1,6 +1,5 @@ # -*- coding:utf-8 -*- from flask import Blueprint - from libs.external_api import ExternalApi bp = Blueprint('files', __name__) diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 089f48f1f0..4227f139dd 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -1,12 +1,11 @@ -from flask import request, Response -from flask_restful import Resource -from werkzeug.exceptions import NotFound - import services from controllers.files import api +from flask import Response, request +from flask_restful import Resource from libs.exception import BaseHTTPException -from services.file_service import FileService from services.account_service import TenantService +from services.file_service import FileService +from werkzeug.exceptions import NotFound class ImagePreviewApi(Resource): diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index b0ee669362..20d6703b1d 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -1,12 +1,10 @@ # -*- coding:utf-8 -*- from flask import Blueprint - from libs.external_api import ExternalApi bp = Blueprint('service_api', __name__, url_prefix='/v1') api = ExternalApi(bp) -from .app import completion, app, conversation, message, audio, file - -from .dataset import document, segment, dataset +from .app import app, audio, completion, conversation, file, message +from .dataset import dataset, document, segment diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 409709b812..2809a9135b 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,10 +1,8 @@ # -*- coding:utf-8 -*- -from flask_restful import fields, marshal_with -from flask import current_app - from controllers.service_api import api from controllers.service_api.wraps import AppApiResource - +from flask import current_app +from flask_restful import fields, marshal_with from models.model import App diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 466ead670f..17e9abdb55 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,20 +1,21 @@ import logging -from flask import request -from werkzeug.exceptions import InternalServerError - import services from controllers.service_api import api -from controllers.service_api.app.error import AppUnavailableError, ProviderNotInitializeError, CompletionRequestError, ProviderQuotaExceededError, \ - ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \ - ProviderNotSupportSpeechToTextError +from controllers.service_api.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError, + NoAudioUploadedError, ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, UnsupportedAudioTypeError) from controllers.service_api.wraps import AppApiResource -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError +from flask import request from models.model import App, AppModelConfig from services.audio_service import AudioService -from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ - UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError +from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError) +from werkzeug.exceptions import InternalServerError + class AudioApi(AppApiResource): def post(self, app_model: App, end_user): diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 509fb38cd7..df5df90403 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,24 +1,23 @@ import json import logging -from typing import Union, Generator - -from flask import stream_with_context, Response -from flask_restful import reqparse -from werkzeug.exceptions import NotFound, InternalServerError +from typing import Generator, Union import services from controllers.service_api import api from controllers.service_api.app import create_or_update_end_user_for_user_id -from controllers.service_api.app.error import AppUnavailableError, ProviderNotInitializeError, NotChatAppError, \ - ConversationCompletedError, CompletionRequestError, ProviderQuotaExceededError, \ - ProviderModelCurrentlyNotSupportError +from controllers.service_api.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError, + NotChatAppError, ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, ProviderQuotaExceededError) from controllers.service_api.wraps import AppApiResource from core.application_queue_manager import ApplicationQueueManager from core.entities.application_entities import InvokeFrom -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError +from flask import Response, stream_with_context +from flask_restful import reqparse from libs.helper import uuid_value from services.completion_service import CompletionService +from werkzeug.exceptions import InternalServerError, NotFound class CompletionApi(AppApiResource): diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 3e9aa07da6..d2f11678b6 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,17 +1,16 @@ # -*- coding:utf-8 -*- -from flask import request -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range -from werkzeug.exceptions import NotFound - +import services from controllers.service_api import api from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import AppApiResource from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields +from flask import request +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range from libs.helper import TimestampField, uuid_value -import services from services.conversation_service import ConversationService +from werkzeug.exceptions import NotFound class ConversationApi(AppApiResource): diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index b2cb7a05f9..8e7984ced1 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -1,14 +1,13 @@ +import services +from controllers.service_api import api +from controllers.service_api.app import create_or_update_end_user_for_user_id +from controllers.service_api.app.error import (FileTooLargeError, NoFileUploadedError, TooManyFilesError, + UnsupportedFileTypeError) +from controllers.service_api.wraps import AppApiResource +from fields.file_fields import file_fields from flask import request from flask_restful import marshal_with - -from controllers.service_api import api -from controllers.service_api.wraps import AppApiResource -from controllers.service_api.app import create_or_update_end_user_for_user_id -from controllers.service_api.app.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \ - UnsupportedFileTypeError -import services from services.file_service import FileService -from fields.file_fields import file_fields class FileApi(AppApiResource): diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 16106c340e..07c4318b84 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,18 +1,18 @@ # -*- coding:utf-8 -*- -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range -from werkzeug.exceptions import NotFound - import services from controllers.service_api import api from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import AppApiResource -from libs.helper import TimestampField, uuid_value -from services.message_service import MessageService from extensions.ext_database import db -from models.model import Message, EndUser from fields.conversation_fields import message_file_fields +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range +from libs.helper import TimestampField, uuid_value +from models.model import EndUser, Message +from services.message_service import MessageService +from werkzeug.exceptions import NotFound + class MessageListApi(AppApiResource): feedback_fields = { diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 813f929557..6028d7c341 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,13 +1,13 @@ -from flask import request -from flask_restful import reqparse, marshal import services.dataset_service from controllers.service_api import api from controllers.service_api.dataset.error import DatasetNameDuplicateError from controllers.service_api.wraps import DatasetApiResource from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager -from libs.login import current_user from fields.dataset_fields import dataset_detail_fields +from flask import request +from flask_restful import marshal, reqparse +from libs.login import current_user from services.dataset_service import DatasetService diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index e851799f75..d7694070f0 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,24 +1,23 @@ import json -from flask import request -from flask_restful import reqparse, marshal -from flask_login import current_user -from sqlalchemy import desc -from werkzeug.exceptions import NotFound - import services.dataset_service from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError -from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \ - NoFileUploadedError, TooManyFilesError +from controllers.service_api.dataset.error import (ArchivedDocumentImmutableError, DocumentIndexingError, + NoFileUploadedError, TooManyFilesError) from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check -from libs.login import current_user from core.errors.error import ProviderTokenNotInitError from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields +from flask import request +from flask_login import current_user +from flask_restful import marshal, reqparse +from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment from services.dataset_service import DocumentService from services.file_service import FileService +from sqlalchemy import desc +from werkzeug.exceptions import NotFound class DocumentAddByTextApi(DatasetApiResource): diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 9940dba1e1..4cc313e042 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,16 +1,16 @@ -from flask_login import current_user -from flask_restful import reqparse, marshal -from werkzeug.exceptions import NotFound from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check -from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError +from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.segment_fields import segment_fields +from flask_login import current_user +from flask_restful import marshal, reqparse from models.dataset import Dataset, DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService +from werkzeug.exceptions import NotFound class SegmentApi(DatasetApiResource): diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 7320a6e614..16cc3679b0 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -2,16 +2,16 @@ from datetime import datetime from functools import wraps -from flask import request, current_app +from extensions.ext_database import db +from flask import current_app, request from flask_login import user_logged_in from flask_restful import Resource -from werkzeug.exceptions import NotFound, Unauthorized - from libs.login import _get_user -from extensions.ext_database import db -from models.account import Tenant, TenantAccountJoin, Account +from models.account import Account, Tenant, TenantAccountJoin from models.model import ApiToken, App from services.feature_service import FeatureService +from werkzeug.exceptions import NotFound, Unauthorized + def validate_app_token(view=None): def decorator(view): diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 3fba1869ce..27ea0cdb67 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -1,10 +1,9 @@ # -*- coding:utf-8 -*- from flask import Blueprint - from libs.external_api import ExternalApi bp = Blueprint('web', __name__, url_prefix='/api') api = ExternalApi(bp) -from . import completion, app, conversation, message, site, saved_message, audio, passport, file +from . import app, audio, completion, conversation, file, message, passport, saved_message, site diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 07f200ddc2..22b274c72d 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,10 +1,8 @@ # -*- coding:utf-8 -*- -from flask_restful import marshal_with, fields -from flask import current_app - from controllers.web import api from controllers.web.wraps import WebApiResource - +from flask import current_app +from flask_restful import fields, marshal_with from models.model import App diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 67825c2b3b..edbe9b71b8 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -1,21 +1,21 @@ # -*- coding:utf-8 -*- import logging -from flask import request -from werkzeug.exceptions import InternalServerError - import services from controllers.web import api -from controllers.web.error import AppUnavailableError, ProviderNotInitializeError, CompletionRequestError, \ - ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \ - UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError +from controllers.web.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError, + NoAudioUploadedError, ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, UnsupportedAudioTypeError) from controllers.web.wraps import WebApiResource -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from services.audio_service import AudioService -from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ - UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError +from flask import request from models.model import App, AppModelConfig +from services.audio_service import AudioService +from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError) +from werkzeug.exceptions import InternalServerError class AudioApi(WebApiResource): diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 8a492f8ca2..411ab13b3d 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -3,22 +3,21 @@ import json import logging from typing import Generator, Union -from flask import Response, stream_with_context -from flask_restful import reqparse -from werkzeug.exceptions import InternalServerError, NotFound - import services from controllers.web import api -from controllers.web.error import AppUnavailableError, ConversationCompletedError, \ - ProviderNotInitializeError, NotChatAppError, NotCompletionAppError, CompletionRequestError, \ - ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError +from controllers.web.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError, + NotChatAppError, NotCompletionAppError, ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, ProviderQuotaExceededError) from controllers.web.wraps import WebApiResource from core.application_queue_manager import ApplicationQueueManager from core.entities.application_entities import InvokeFrom -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError +from flask import Response, stream_with_context +from flask_restful import reqparse from libs.helper import uuid_value from services.completion_service import CompletionService +from werkzeug.exceptions import InternalServerError, NotFound # define completion api for user diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 2f6e79ba12..1f17f7883e 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,16 +1,15 @@ # -*- coding:utf-8 -*- -from flask_restful import fields, reqparse, marshal_with -from flask_restful.inputs import int_range -from werkzeug.exceptions import NotFound - from controllers.web import api from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range from libs.helper import TimestampField, uuid_value from services.conversation_service import ConversationService -from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError +from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.web_conversation_service import WebConversationService +from werkzeug.exceptions import NotFound class ConversationListApi(WebApiResource): diff --git a/api/controllers/web/file.py b/api/controllers/web/file.py index 985e9c5b58..c43fe6fdf5 100644 --- a/api/controllers/web/file.py +++ b/api/controllers/web/file.py @@ -1,13 +1,11 @@ +import services +from controllers.web import api +from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError +from controllers.web.wraps import WebApiResource +from fields.file_fields import file_fields from flask import request from flask_restful import marshal_with - -from controllers.web import api -from controllers.web.wraps import WebApiResource -from controllers.web.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \ - UnsupportedFileTypeError -import services from services.file_service import FileService -from fields.file_fields import file_fields class FileApi(WebApiResource): diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 9734aa177f..3651a71096 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -3,27 +3,27 @@ import json import logging from typing import Generator, Union -from flask import stream_with_context, Response -from flask_restful import reqparse, fields, marshal_with -from flask_restful.inputs import int_range -from werkzeug.exceptions import NotFound, InternalServerError - import services from controllers.web import api -from controllers.web.error import NotChatAppError, CompletionRequestError, ProviderNotInitializeError, \ - AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \ - ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError +from controllers.web.error import (AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError, + CompletionRequestError, NotChatAppError, NotCompletionAppError, + ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, + ProviderQuotaExceededError) from controllers.web.wraps import WebApiResource from core.entities.application_entities import InvokeFrom -from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from libs.helper import uuid_value, TimestampField +from fields.conversation_fields import message_file_fields +from flask import Response, stream_with_context +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range +from libs.helper import TimestampField, uuid_value from services.completion_service import CompletionService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService -from fields.conversation_fields import message_file_fields +from werkzeug.exceptions import InternalServerError, NotFound class MessageListApi(WebApiResource): diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index a5d3e388ac..bc6cf6028b 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,12 +1,14 @@ # -*- coding:utf-8 -*- import uuid + from controllers.web import api -from flask_restful import Resource -from flask import request -from werkzeug.exceptions import Unauthorized, NotFound -from models.model import Site, EndUser, App from extensions.ext_database import db +from flask import request +from flask_restful import Resource from libs.passport import PassportService +from models.model import App, EndUser, Site +from werkzeug.exceptions import NotFound, Unauthorized + class PassportResource(Resource): """Base resource for passport.""" diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 888032cdee..b353b9682e 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,15 +1,13 @@ -from flask_restful import reqparse, marshal_with, fields -from flask_restful.inputs import int_range -from werkzeug.exceptions import NotFound - from controllers.web import api from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource -from libs.helper import uuid_value, TimestampField +from fields.conversation_fields import message_file_fields +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range +from libs.helper import TimestampField, uuid_value from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -from fields.conversation_fields import message_file_fields - +from werkzeug.exceptions import NotFound feedback_fields = { 'rating': fields.String diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 0f63e6087b..9f1297a06c 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,15 +1,14 @@ # -*- coding:utf-8 -*- import os -from flask_restful import fields, marshal_with -from flask import current_app -from werkzeug.exceptions import Forbidden - from controllers.web import api from controllers.web.wraps import WebApiResource from extensions.ext_database import db +from flask import current_app +from flask_restful import fields, marshal_with from models.model import Site from services.feature_service import FeatureService +from werkzeug.exceptions import Forbidden class AppSiteApi(WebApiResource): diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 28a46d329e..0803a3b5ea 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,13 +1,13 @@ # -*- coding:utf-8 -*- from functools import wraps +from extensions.ext_database import db from flask import request from flask_restful import Resource +from libs.passport import PassportService +from models.model import App, EndUser, Site from werkzeug.exceptions import NotFound, Unauthorized -from extensions.ext_database import db -from models.model import App, EndUser, Site -from libs.passport import PassportService def validate_jwt_token(view=None): def decorator(view): diff --git a/api/core/agent/agent/agent_llm_callback.py b/api/core/agent/agent/agent_llm_callback.py index 04b9bab141..8331731200 100644 --- a/api/core/agent/agent/agent_llm_callback.py +++ b/api/core/agent/agent/agent_llm_callback.py @@ -1,10 +1,10 @@ import logging -from typing import Optional, List +from typing import List, Optional from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult -from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) diff --git a/api/core/agent/agent/calc_token_mixin.py b/api/core/agent/agent/calc_token_mixin.py index d8cdf9fe2b..1ca6c49812 100644 --- a/api/core/agent/agent/calc_token_mixin.py +++ b/api/core/agent/agent/calc_token_mixin.py @@ -1,12 +1,11 @@ from typing import List, cast -from langchain.schema import BaseMessage - from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from langchain.schema import BaseMessage class CalcTokenMixin: diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/agent/agent/multi_dataset_router_agent.py index 9b622a8689..c13641b84d 100644 --- a/api/core/agent/agent/multi_dataset_router_agent.py +++ b/api/core/agent/agent/multi_dataset_router_agent.py @@ -1,20 +1,19 @@ -from typing import Tuple, List, Any, Union, Sequence, Optional, cast +from typing import Any, List, Optional, Sequence, Tuple, Union, cast -from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent +from core.entities.application_entities import ModelConfigEntity +from core.entities.message_entities import lc_messages_to_prompt_messages +from core.model_manager import ModelInstance +from core.model_runtime.entities.message_entities import PromptMessageTool +from core.third_party.langchain.llms.fake import FakeLLM +from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks from langchain.prompts.chat import BaseMessagePromptTemplate -from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage +from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage from langchain.tools import BaseTool from pydantic import root_validator -from core.entities.application_entities import ModelConfigEntity -from core.model_manager import ModelInstance -from core.entities.message_entities import lc_messages_to_prompt_messages -from core.model_runtime.entities.message_entities import PromptMessageTool -from core.third_party.langchain.llms.fake import FakeLLM - class MultiDatasetRouterAgent(OpenAIFunctionsAgent): """ diff --git a/api/core/agent/agent/openai_function_call.py b/api/core/agent/agent/openai_function_call.py index 181208eb6a..e17282a293 100644 --- a/api/core/agent/agent/openai_function_call.py +++ b/api/core/agent/agent/openai_function_call.py @@ -1,28 +1,26 @@ -from typing import List, Tuple, Any, Union, Sequence, Optional, cast +from typing import Any, List, Optional, Sequence, Tuple, Union, cast -from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent -from langchain.agents.openai_functions_agent.base import _parse_ai_message, \ - _format_intermediate_steps +from core.agent.agent.agent_llm_callback import AgentLLMCallback +from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError +from core.chain.llm_chain import LLMChain +from core.entities.application_entities import ModelConfigEntity +from core.entities.message_entities import lc_messages_to_prompt_messages +from core.model_manager import ModelInstance +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.third_party.langchain.llms.fake import FakeLLM +from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent +from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken from langchain.memory.prompt import SUMMARY_PROMPT from langchain.prompts.chat import BaseMessagePromptTemplate -from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \ - get_buffer_string +from langchain.schema import (AgentAction, AgentFinish, AIMessage, BaseMessage, HumanMessage, SystemMessage, + get_buffer_string) from langchain.tools import BaseTool from pydantic import root_validator -from core.agent.agent.agent_llm_callback import AgentLLMCallback -from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin -from core.chain.llm_chain import LLMChain -from core.entities.application_entities import ModelConfigEntity -from core.model_manager import ModelInstance -from core.entities.message_entities import lc_messages_to_prompt_messages -from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.third_party.langchain.llms.fake import FakeLLM - class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin): moving_summary_buffer: str = "" diff --git a/api/core/agent/agent/output_parser/structured_chat.py b/api/core/agent/agent/output_parser/structured_chat.py index f0332a007a..c2d748d8f6 100644 --- a/api/core/agent/agent/output_parser/structured_chat.py +++ b/api/core/agent/agent/output_parser/structured_chat.py @@ -2,8 +2,8 @@ import json import re from typing import Union -from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser, \ - logger +from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser +from langchain.agents.structured_chat.output_parser import logger from langchain.schema import AgentAction, AgentFinish, OutputParserException diff --git a/api/core/agent/agent/structed_multi_dataset_router_agent.py b/api/core/agent/agent/structed_multi_dataset_router_agent.py index bc35ef0371..c8e6a84b09 100644 --- a/api/core/agent/agent/structed_multi_dataset_router_agent.py +++ b/api/core/agent/agent/structed_multi_dataset_router_agent.py @@ -1,18 +1,17 @@ import re -from typing import List, Tuple, Any, Union, Sequence, Optional, cast - -from langchain import BasePromptTemplate, PromptTemplate -from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent -from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE -from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.manager import Callbacks -from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate -from langchain.schema import AgentAction, AgentFinish, OutputParserException -from langchain.tools import BaseTool -from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX +from typing import Any, List, Optional, Sequence, Tuple, Union, cast from core.chain.llm_chain import LLMChain from core.entities.application_entities import ModelConfigEntity +from langchain import BasePromptTemplate, PromptTemplate +from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent +from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE +from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX +from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import Callbacks +from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate +from langchain.schema import AgentAction, AgentFinish, OutputParserException +from langchain.tools import BaseTool FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. diff --git a/api/core/agent/agent/structured_chat.py b/api/core/agent/agent/structured_chat.py index bbce1ca440..af0130b314 100644 --- a/api/core/agent/agent/structured_chat.py +++ b/api/core/agent/agent/structured_chat.py @@ -1,23 +1,22 @@ import re -from typing import List, Tuple, Any, Union, Sequence, Optional, cast - -from langchain import BasePromptTemplate, PromptTemplate -from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent -from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE -from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.manager import Callbacks -from langchain.memory.prompt import SUMMARY_PROMPT -from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate -from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \ - get_buffer_string -from langchain.tools import BaseTool -from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX +from typing import Any, List, Optional, Sequence, Tuple, Union, cast from core.agent.agent.agent_llm_callback import AgentLLMCallback from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError from core.chain.llm_chain import LLMChain from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages +from langchain import BasePromptTemplate, PromptTemplate +from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent +from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE +from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX +from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import Callbacks +from langchain.memory.prompt import SUMMARY_PROMPT +from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate +from langchain.schema import (AgentAction, AgentFinish, AIMessage, BaseMessage, HumanMessage, OutputParserException, + get_buffer_string) +from langchain.tools import BaseTool FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. diff --git a/api/core/agent/agent_executor.py b/api/core/agent/agent_executor.py index 52cc424ffb..4aa48337e5 100644 --- a/api/core/agent/agent_executor.py +++ b/api/core/agent/agent_executor.py @@ -1,11 +1,6 @@ import enum import logging -from typing import Union, Optional - -from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent -from langchain.callbacks.manager import Callbacks -from langchain.tools import BaseTool -from pydantic import BaseModel, Extra +from typing import Optional, Union from core.agent.agent.agent_llm_callback import AgentLLMCallback from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent @@ -13,8 +8,6 @@ from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionC from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent -from langchain.agents import AgentExecutor as LCAgentExecutor - from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import prompt_messages_to_lc_messages from core.helper import moderation @@ -22,6 +15,11 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.errors.invoke import InvokeError from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool from core.tool.dataset_retriever_tool import DatasetRetrieverTool +from langchain.agents import AgentExecutor as LCAgentExecutor +from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent +from langchain.callbacks.manager import Callbacks +from langchain.tools import BaseTool +from pydantic import BaseModel, Extra class PlanningStrategy(str, enum.Enum): diff --git a/api/core/app_runner/agent_app_runner.py b/api/core/app_runner/agent_app_runner.py index 8f951b7855..cc375056ce 100644 --- a/api/core/app_runner/agent_app_runner.py +++ b/api/core/app_runner/agent_app_runner.py @@ -4,16 +4,16 @@ from typing import cast from core.agent.agent.agent_llm_callback import AgentLLMCallback from core.app_runner.app_runner import AppRunner -from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler -from core.entities.application_entities import ApplicationGenerateEntity, PromptTemplateEntity, ModelConfigEntity from core.application_queue_manager import ApplicationQueueManager +from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler +from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, PromptTemplateEntity from core.features.agent_runner import AgentRunnerFeature from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db -from models.model import Conversation, Message, App, MessageChain, MessageAgentThought +from models.model import App, Conversation, Message, MessageAgentThought, MessageChain logger = logging.getLogger(__name__) diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py index fe3c08f03f..a2edbfc3ab 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app_runner/app_runner.py @@ -1,12 +1,12 @@ import time -from typing import cast, Optional, List, Tuple, Generator, Union +from typing import Generator, List, Optional, Tuple, Union, cast from core.application_queue_manager import ApplicationQueueManager, PublishFrom -from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity +from core.entities.application_entities import AppOrchestrationConfigEntity, ModelConfigEntity, PromptTemplateEntity from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import PromptMessage, AssistantPromptMessage +from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index a77baaa495..1190b6a653 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -1,11 +1,11 @@ import logging -from typing import Tuple, Optional +from typing import Optional, Tuple from core.app_runner.app_runner import AppRunner -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \ - AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.entities.application_entities import (ApplicationGenerateEntity, AppOrchestrationConfigEntity, DatasetEntity, + ExternalDataVariableEntity, InvokeFrom, ModelConfigEntity) from core.features.annotation_reply import AnnotationReplyFeature from core.features.dataset_retrieval import DatasetRetrievalFeature from core.features.external_data_fetch import ExternalDataFetchFeature @@ -17,7 +17,7 @@ from core.model_runtime.entities.message_entities import PromptMessage from core.moderation.base import ModerationException from core.prompt.prompt_transform import AppMode from extensions.ext_database import db -from models.model import Conversation, Message, App, MessageAnnotation +from models.model import App, Conversation, Message, MessageAnnotation logger = logging.getLogger(__name__) diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py index 69689fe167..d9f057fbf8 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app_runner/generate_task_pipeline.py @@ -1,25 +1,25 @@ import json import logging import time -from typing import Union, Generator, cast, Optional +from typing import Generator, Optional, Union, cast -from pydantic import BaseModel - -from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule -from core.entities.application_entities import ApplicationGenerateEntity +from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler from core.application_queue_manager import ApplicationQueueManager, PublishFrom -from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \ - QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \ - AnnotationReplyEvent -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, \ - TextPromptMessageContent, PromptMessageContentType, ImagePromptMessageContent, PromptMessage -from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError +from core.entities.application_entities import ApplicationGenerateEntity +from core.entities.queue_entities import (AnnotationReplyEvent, QueueAgentThoughtEvent, QueueErrorEvent, + QueueMessageEndEvent, QueueMessageEvent, QueueMessageReplaceEvent, + QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent) +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, + PromptMessage, PromptMessageContentType, PromptMessageRole, + TextPromptMessageContent) +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.prompt_template import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db -from models.model import Message, Conversation, MessageAgentThought +from models.model import Conversation, Message, MessageAgentThought +from pydantic import BaseModel from services.annotation_service import AppAnnotationService logger = logging.getLogger(__name__) diff --git a/api/core/app_runner/moderation_handler.py b/api/core/app_runner/moderation_handler.py index 2917da6f29..ecb32ee437 100644 --- a/api/core/app_runner/moderation_handler.py +++ b/api/core/app_runner/moderation_handler.py @@ -1,14 +1,13 @@ import logging import threading import time -from typing import Any, Optional, Dict - -from flask import current_app, Flask -from pydantic import BaseModel +from typing import Any, Dict, Optional from core.application_queue_manager import PublishFrom from core.moderation.base import ModerationAction, ModerationOutputsResult from core.moderation.factory import ModerationFactory +from flask import Flask, current_app +from pydantic import BaseModel logger = logging.getLogger(__name__) diff --git a/api/core/application_manager.py b/api/core/application_manager.py index 88500f3a47..7a0bed3ded 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -2,31 +2,32 @@ import json import logging import threading import uuid -from typing import cast, Optional, Any, Union, Generator, Tuple - -from flask import Flask, current_app -from pydantic import ValidationError +from typing import Any, Generator, Optional, Tuple, Union, cast from core.app_runner.agent_app_runner import AgentApplicationRunner from core.app_runner.basic_app_runner import BasicApplicationRunner from core.app_runner.generate_task_pipeline import GenerateTaskPipeline -from core.entities.application_entities import ApplicationGenerateEntity, AppOrchestrationConfigEntity, \ - ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \ - AdvancedCompletionPromptTemplateEntity, ExternalDataVariableEntity, DatasetEntity, DatasetRetrieveConfigEntity, \ - AgentEntity, AgentToolEntity, FileUploadEntity, SensitiveWordAvoidanceEntity, InvokeFrom +from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom +from core.entities.application_entities import (AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, AgentEntity, AgentToolEntity, + ApplicationGenerateEntity, AppOrchestrationConfigEntity, DatasetEntity, + DatasetRetrieveConfigEntity, ExternalDataVariableEntity, + FileUploadEntity, InvokeFrom, ModelConfigEntity, PromptTemplateEntity, + SensitiveWordAvoidanceEntity) from core.entities.model_entities import ModelStatus +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.file.file_obj import FileObj -from core.errors.error import QuotaExceededError, ProviderTokenNotInitError, ModelCurrentlyNotSupportError from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.prompt_template import PromptTemplateParser from core.provider_manager import ProviderManager -from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom from extensions.ext_database import db +from flask import Flask, current_app from models.account import Account -from models.model import EndUser, Conversation, Message, MessageFile, App +from models.model import App, Conversation, EndUser, Message, MessageFile +from pydantic import ValidationError logger = logging.getLogger(__name__) diff --git a/api/core/application_queue_manager.py b/api/core/application_queue_manager.py index 678d08d772..09b92c5f84 100644 --- a/api/core/application_queue_manager.py +++ b/api/core/application_queue_manager.py @@ -1,17 +1,17 @@ import queue import time from enum import Enum -from typing import Generator, Any - -from sqlalchemy.orm import DeclarativeMeta +from typing import Any, Generator from core.entities.application_entities import InvokeFrom -from core.entities.queue_entities import QueueStopEvent, AppQueueEvent, QueuePingEvent, QueueErrorEvent, \ - QueueAgentThoughtEvent, QueueMessageEndEvent, QueueRetrieverResourcesEvent, QueueMessageReplaceEvent, \ - QueueMessageEvent, QueueMessage, AnnotationReplyEvent +from core.entities.queue_entities import (AnnotationReplyEvent, AppQueueEvent, QueueAgentThoughtEvent, QueueErrorEvent, + QueueMessage, QueueMessageEndEvent, QueueMessageEvent, + QueueMessageReplaceEvent, QueuePingEvent, QueueRetrieverResourcesEvent, + QueueStopEvent) from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from extensions.ext_redis import redis_client from models.model import MessageAgentThought +from sqlalchemy.orm import DeclarativeMeta class PublishFrom(Enum): diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py index 1c9d3d7139..edee77e25f 100644 --- a/api/core/callback_handler/agent_loop_gather_callback_handler.py +++ b/api/core/callback_handler/agent_loop_gather_callback_handler.py @@ -1,21 +1,19 @@ import json import logging import time - -from typing import Any, Dict, List, Union, Optional, cast - -from langchain.agents import openai_functions_agent, openai_functions_multi_agent -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage +from typing import Any, Dict, List, Optional, Union, cast from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.callback_handler.entity.agent_loop import AgentLoop from core.entities.application_entities import ModelConfigEntity from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessage +from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db -from models.model import MessageChain, MessageAgentThought, Message +from langchain.agents import openai_functions_agent, openai_functions_multi_agent +from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import AgentAction, AgentFinish, BaseMessage, ChatGeneration, LLMResult +from models.model import Message, MessageAgentThought, MessageChain class AgentLoopGatherCallbackHandler(BaseCallbackHandler): diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index a7dbbac393..9947028806 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,11 +1,10 @@ from typing import List, Union -from langchain.schema import Document - from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.entities.application_entities import InvokeFrom from extensions.ext_database import db -from models.dataset import DocumentSegment, DatasetQuery +from langchain.schema import Document +from models.dataset import DatasetQuery, DocumentSegment from models.model import DatasetRetrieverResource diff --git a/api/core/callback_handler/std_out_callback_handler.py b/api/core/callback_handler/std_out_callback_handler.py index 750d9d7399..9f586d2c9b 100644 --- a/api/core/callback_handler/std_out_callback_handler.py +++ b/api/core/callback_handler/std_out_callback_handler.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Union from langchain.callbacks.base import BaseCallbackHandler from langchain.input import print_text -from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage +from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult class DifyStdOutCallbackHandler(BaseCallbackHandler): diff --git a/api/core/chain/llm_chain.py b/api/core/chain/llm_chain.py index 4939ad9b13..20b71f2f64 100644 --- a/api/core/chain/llm_chain.py +++ b/api/core/chain/llm_chain.py @@ -1,15 +1,14 @@ -from typing import List, Dict, Any, Optional - -from langchain import LLMChain as LCLLMChain -from langchain.callbacks.manager import CallbackManagerForChainRun -from langchain.schema import LLMResult, Generation -from langchain.schema.language_model import BaseLanguageModel +from typing import Any, Dict, List, Optional from core.agent.agent.agent_llm_callback import AgentLLMCallback from core.entities.application_entities import ModelConfigEntity -from core.model_manager import ModelInstance from core.entities.message_entities import lc_messages_to_prompt_messages +from core.model_manager import ModelInstance from core.third_party.langchain.llms.fake import FakeLLM +from langchain import LLMChain as LCLLMChain +from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.schema import Generation, LLMResult +from langchain.schema.language_model import BaseLanguageModel class LLMChain(LCLLMChain): diff --git a/api/core/data_loader/file_extractor.py b/api/core/data_loader/file_extractor.py index 00f0e607c8..fe445de93c 100644 --- a/api/core/data_loader/file_extractor.py +++ b/api/core/data_loader/file_extractor.py @@ -1,12 +1,8 @@ import tempfile from pathlib import Path -from typing import List, Union, Optional +from typing import List, Optional, Union import requests -from flask import current_app -from langchain.document_loaders import TextLoader, Docx2txtLoader -from langchain.schema import Document - from core.data_loader.loader.csv_loader import CSVLoader from core.data_loader.loader.excel import ExcelLoader from core.data_loader.loader.html import HTMLLoader @@ -20,6 +16,9 @@ from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredP from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader from extensions.ext_storage import storage +from flask import current_app +from langchain.document_loaders import Docx2txtLoader, TextLoader +from langchain.schema import Document from models.model import UploadFile SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain'] diff --git a/api/core/data_loader/loader/csv_loader.py b/api/core/data_loader/loader/csv_loader.py index 0789121cf1..a4d4ed2b39 100644 --- a/api/core/data_loader/loader/csv_loader.py +++ b/api/core/data_loader/loader/csv_loader.py @@ -1,6 +1,6 @@ -import logging import csv -from typing import Optional, Dict, List +import logging +from typing import Dict, List, Optional from langchain.document_loaders import CSVLoader as LCCSVLoader from langchain.document_loaders.helpers import detect_file_encodings diff --git a/api/core/data_loader/loader/markdown.py b/api/core/data_loader/loader/markdown.py index 4e6c0d5637..545c6b10ed 100644 --- a/api/core/data_loader/loader/markdown.py +++ b/api/core/data_loader/loader/markdown.py @@ -1,6 +1,6 @@ import logging import re -from typing import Optional, List, Tuple, cast +from typing import List, Optional, Tuple, cast from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.helpers import detect_file_encodings diff --git a/api/core/data_loader/loader/notion.py b/api/core/data_loader/loader/notion.py index 2162df83e1..914c04d5c0 100644 --- a/api/core/data_loader/loader/notion.py +++ b/api/core/data_loader/loader/notion.py @@ -1,13 +1,12 @@ import json import logging -from typing import List, Dict, Any, Optional +from typing import Any, Dict, List, Optional import requests +from extensions.ext_database import db from flask import current_app from langchain.document_loaders.base import BaseLoader from langchain.schema import Document - -from extensions.ext_database import db from models.dataset import Document as DocumentModel from models.source import DataSourceBinding diff --git a/api/core/data_loader/loader/pdf.py b/api/core/data_loader/loader/pdf.py index 881d0026b5..8b08393d91 100644 --- a/api/core/data_loader/loader/pdf.py +++ b/api/core/data_loader/loader/pdf.py @@ -1,11 +1,10 @@ import logging from typing import List, Optional +from extensions.ext_storage import storage from langchain.document_loaders import PyPDFium2Loader from langchain.document_loaders.base import BaseLoader from langchain.schema import Document - -from extensions.ext_storage import storage from models.model import UploadFile logger = logging.getLogger(__name__) diff --git a/api/core/data_loader/loader/unstructured/unstructured_eml.py b/api/core/data_loader/loader/unstructured/unstructured_eml.py index fa097ac37b..26e0ce8cda 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_eml.py +++ b/api/core/data_loader/loader/unstructured/unstructured_eml.py @@ -1,6 +1,7 @@ -import logging import base64 +import logging from typing import List + from bs4 import BeautifulSoup from langchain.document_loaders.base import BaseLoader from langchain.schema import Document diff --git a/api/core/data_loader/loader/unstructured/unstructured_msg.py b/api/core/data_loader/loader/unstructured/unstructured_msg.py index 1e18dbcdf5..ba9f9c0340 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_msg.py +++ b/api/core/data_loader/loader/unstructured/unstructured_msg.py @@ -1,6 +1,6 @@ import logging import re -from typing import Optional, List, Tuple, cast +from typing import List, Optional, Tuple, cast from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.helpers import detect_file_encodings diff --git a/api/core/data_loader/loader/unstructured/unstructured_ppt.py b/api/core/data_loader/loader/unstructured/unstructured_ppt.py index 4560c262e9..1ad5dbc216 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_ppt.py +++ b/api/core/data_loader/loader/unstructured/unstructured_ppt.py @@ -1,6 +1,6 @@ import logging import re -from typing import Optional, List, Tuple, cast +from typing import List, Optional, Tuple, cast from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.helpers import detect_file_encodings diff --git a/api/core/data_loader/loader/unstructured/unstructured_pptx.py b/api/core/data_loader/loader/unstructured/unstructured_pptx.py index 7bb3c3af71..e5cff52fd9 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_pptx.py +++ b/api/core/data_loader/loader/unstructured/unstructured_pptx.py @@ -1,6 +1,6 @@ import logging import re -from typing import Optional, List, Tuple, cast +from typing import List, Optional, Tuple, cast from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.helpers import detect_file_encodings diff --git a/api/core/data_loader/loader/unstructured/unstructured_text.py b/api/core/data_loader/loader/unstructured/unstructured_text.py index f552f8bc86..779977cd16 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_text.py +++ b/api/core/data_loader/loader/unstructured/unstructured_text.py @@ -1,6 +1,6 @@ import logging import re -from typing import Optional, List, Tuple, cast +from typing import List, Optional, Tuple, cast from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.helpers import detect_file_encodings diff --git a/api/core/data_loader/loader/unstructured/unstructured_xml.py b/api/core/data_loader/loader/unstructured/unstructured_xml.py index 8c09512fb9..6d9d09c1a1 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_xml.py +++ b/api/core/data_loader/loader/unstructured/unstructured_xml.py @@ -1,6 +1,6 @@ import logging import re -from typing import Optional, List, Tuple, cast +from typing import List, Optional, Tuple, cast from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.helpers import detect_file_encodings diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py index 77a5dde9ed..49e87ec340 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/docstore/dataset_docstore.py @@ -1,13 +1,12 @@ from typing import Any, Dict, Optional, Sequence, cast -from langchain.schema import Document -from sqlalchemy import func - from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from extensions.ext_database import db +from langchain.schema import Document from models.dataset import Dataset, DocumentSegment +from sqlalchemy import func class DatasetDocumentStore: diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 0022c8b141..285e2ba388 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -2,13 +2,12 @@ import logging from typing import List, Optional import numpy as np -from langchain.embeddings.base import Embeddings -from sqlalchemy.exc import IntegrityError - from core.model_manager import ModelInstance from extensions.ext_database import db +from langchain.embeddings.base import Embeddings from libs import helper from models.dataset import Embedding +from sqlalchemy.exc import IntegrityError logger = logging.getLogger(__name__) diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index 7e34eed51e..47a1ac6510 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -1,12 +1,11 @@ from enum import Enum -from typing import Optional, Any, cast - -from pydantic import BaseModel +from typing import Any, Optional, cast from core.entities.provider_configuration import ProviderModelBundle from core.file.file_obj import FileObj from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.model_entities import AIModelEntity +from pydantic import BaseModel class ModelConfigEntity(BaseModel): diff --git a/api/core/entities/message_entities.py b/api/core/entities/message_entities.py index d72f436aad..9b0b287f28 100644 --- a/api/core/entities/message_entities.py +++ b/api/core/entities/message_entities.py @@ -1,12 +1,12 @@ import enum from typing import Any, cast -from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, + PromptMessage, SystemPromptMessage, TextPromptMessageContent, + ToolPromptMessage, UserPromptMessage) +from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage from pydantic import BaseModel -from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage, TextPromptMessageContent, \ - ImagePromptMessageContent, AssistantPromptMessage, SystemPromptMessage, ToolPromptMessage - class PromptMessageFileType(enum.Enum): IMAGE = 'image' diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 6b49ec9248..3888807227 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,11 +1,10 @@ from enum import Enum from typing import Optional -from pydantic import BaseModel - from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ProviderModel, ModelType -from core.model_runtime.entities.provider_entities import SimpleProviderEntity, ProviderEntity +from core.model_runtime.entities.model_entities import ModelType, ProviderModel +from core.model_runtime.entities.provider_entities import ProviderEntity, SimpleProviderEntity +from pydantic import BaseModel class ModelStatus(Enum): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 3af0c2b4d9..cb31d01e99 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,25 +1,23 @@ import datetime import json import logging - from json import JSONDecodeError -from typing import Optional, List, Dict, Tuple, Iterator +from typing import Dict, Iterator, List, Optional, Tuple -from pydantic import BaseModel - -from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity -from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus +from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity +from core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType -from core.model_runtime.entities.model_entities import ModelType, FetchFrom -from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \ - ConfigurateMethod +from core.model_runtime.entities.model_entities import FetchFrom, ModelType +from core.model_runtime.entities.provider_entities import (ConfigurateMethod, CredentialFormSchema, FormType, + ProviderEntity) from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.utils import encoders from extensions.ext_database import db -from models.provider import ProviderType, Provider, ProviderModel, TenantPreferredModelProvider +from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider +from pydantic import BaseModel logger = logging.getLogger(__name__) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 866b064a4e..83d85c34ae 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,10 +1,9 @@ from enum import Enum from typing import Optional -from pydantic import BaseModel - from core.model_runtime.entities.model_entities import ModelType from models.provider import ProviderQuotaType +from pydantic import BaseModel class QuotaUnit(Enum): diff --git a/api/core/entities/queue_entities.py b/api/core/entities/queue_entities.py index 20b434b05d..858b00ea64 100644 --- a/api/core/entities/queue_entities.py +++ b/api/core/entities/queue_entities.py @@ -1,9 +1,8 @@ from enum import Enum from typing import Any -from pydantic import BaseModel - from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from pydantic import BaseModel class QueueEvent(Enum): diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index 8ce7edabf2..c244fe88f1 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -1,7 +1,6 @@ import os import requests - from models.api_based_extension import APIBasedExtensionPoint diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py index 6517e41ccd..29e892c58a 100644 --- a/api/core/extension/extension.py +++ b/api/core/extension/extension.py @@ -1,4 +1,4 @@ -from core.extension.extensible import ModuleExtension, ExtensionModule +from core.extension.extensible import ExtensionModule, ModuleExtension from core.external_data_tool.base import ExternalDataTool from core.moderation.base import Moderation diff --git a/api/core/external_data_tool/base.py b/api/core/external_data_tool/base.py index 1c181ff3c5..0db736f096 100644 --- a/api/core/external_data_tool/base.py +++ b/api/core/external_data_tool/base.py @@ -1,4 +1,4 @@ -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod from typing import Optional from core.extension.extensible import Extensible, ExtensionModule diff --git a/api/core/features/agent_runner.py b/api/core/features/agent_runner.py index 2c63c66699..ba9c3218fa 100644 --- a/api/core/features/agent_runner.py +++ b/api/core/features/agent_runner.py @@ -1,19 +1,14 @@ import logging -from typing import cast, Optional, List - -from langchain import WikipediaAPIWrapper -from langchain.callbacks.base import BaseCallbackHandler -from langchain.tools import BaseTool, WikipediaQueryRun, Tool -from pydantic import BaseModel, Field +from typing import List, Optional, cast from core.agent.agent.agent_llm_callback import AgentLLMCallback -from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor +from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy from core.application_queue_manager import ApplicationQueueManager from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler -from core.entities.application_entities import ModelConfigEntity, InvokeFrom, \ - AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity +from core.entities.application_entities import (AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity, InvokeFrom, + ModelConfigEntity) from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers import model_provider_factory @@ -21,11 +16,15 @@ from core.model_runtime.model_providers.__base.large_language_model import Large from core.tool.current_datetime_tool import DatetimeTool from core.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tool.provider.serpapi_provider import SerpAPIToolProvider -from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput +from core.tool.serpapi_wrapper import OptimizedSerpAPIInput, OptimizedSerpAPIWrapper from core.tool.web_reader_tool import WebReaderTool from extensions.ext_database import db +from langchain import WikipediaAPIWrapper +from langchain.callbacks.base import BaseCallbackHandler +from langchain.tools import BaseTool, Tool, WikipediaQueryRun from models.dataset import Dataset from models.model import Message +from pydantic import BaseModel, Field logger = logging.getLogger(__name__) diff --git a/api/core/features/annotation_reply.py b/api/core/features/annotation_reply.py index 060a6c20c3..09945aaf6e 100644 --- a/api/core/features/annotation_reply.py +++ b/api/core/features/annotation_reply.py @@ -1,16 +1,15 @@ import logging from typing import Optional -from flask import current_app - from core.embedding.cached_embedding import CacheEmbedding from core.entities.application_entities import InvokeFrom from core.index.vector_index.vector_index import VectorIndex from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db +from flask import current_app from models.dataset import Dataset -from models.model import App, Message, AppAnnotationSetting, MessageAnnotation +from models.model import App, AppAnnotationSetting, Message, MessageAnnotation from services.annotation_service import AppAnnotationService from services.dataset_service import DatasetCollectionBindingService diff --git a/api/core/features/dataset_retrieval.py b/api/core/features/dataset_retrieval.py index 3476bf9ad9..2accbafbdd 100644 --- a/api/core/features/dataset_retrieval.py +++ b/api/core/features/dataset_retrieval.py @@ -1,16 +1,15 @@ -from typing import cast, Optional, List +from typing import List, Optional, cast -from langchain.tools import BaseTool - -from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor +from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import DatasetEntity, ModelConfigEntity, InvokeFrom, DatasetRetrieveConfigEntity +from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool from core.tool.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db +from langchain.tools import BaseTool from models.dataset import Dataset diff --git a/api/core/features/external_data_fetch.py b/api/core/features/external_data_fetch.py index 272b7cee95..791fbf6ae3 100644 --- a/api/core/features/external_data_fetch.py +++ b/api/core/features/external_data_fetch.py @@ -1,14 +1,12 @@ import concurrent import json import logging - from concurrent.futures import ThreadPoolExecutor -from typing import Tuple, Optional - -from flask import current_app, Flask +from typing import Optional, Tuple from core.entities.application_entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory +from flask import Flask, current_app logger = logging.getLogger(__name__) diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index 7f6e79b15b..3ebe531607 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -1,12 +1,11 @@ import enum from typing import Optional -from pydantic import BaseModel - from core.file.upload_file_parser import UploadFileParser from core.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db from models.model import UploadFile +from pydantic import BaseModel class FileType(enum.Enum): diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index faa1c8badb..8d205f93cb 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -1,12 +1,11 @@ -from typing import List, Union, Optional, Dict +from typing import Dict, List, Optional, Union import requests - -from core.file.file_obj import FileObj, FileType, FileTransferMethod +from core.file.file_obj import FileObj, FileTransferMethod, FileType from core.file.upload_file_parser import SUPPORT_EXTENSIONS from extensions.ext_database import db from models.account import Account -from models.model import MessageFile, EndUser, AppModelConfig, UploadFile +from models.model import AppModelConfig, EndUser, MessageFile, UploadFile class MessageFileParser: diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py index 30125e0ea5..3eba8bb06c 100644 --- a/api/core/file/upload_file_parser.py +++ b/api/core/file/upload_file_parser.py @@ -6,9 +6,8 @@ import os import time from typing import Optional -from flask import current_app - from extensions.ext_storage import storage +from flask import current_app SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py index acb7a6d2c9..2a15575360 100644 --- a/api/core/generator/llm_generator.py +++ b/api/core/generator/llm_generator.py @@ -1,17 +1,15 @@ import json import logging -from langchain.schema import OutputParserException - from core.model_manager import ModelManager -from core.model_runtime.entities.message_entities import UserPromptMessage, SystemPromptMessage +from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser - from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.prompt.prompt_template import PromptTemplateParser from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT +from langchain.schema import OutputParserException class LLMGenerator: diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 4d57d2d5fe..fcf293dc1c 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -2,7 +2,6 @@ import base64 from extensions.ext_database import db from libs import rsa - from models.account import Tenant diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 6537ef81fe..2273287cdc 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,12 +1,11 @@ import os from typing import Optional -from flask import Flask -from pydantic import BaseModel - from core.entities.provider_entities import QuotaUnit, RestrictModel from core.model_runtime.entities.model_entities import ModelType +from flask import Flask from models.provider import ProviderQuotaType +from pydantic import BaseModel class HostingQuota(BaseModel): diff --git a/api/core/index/base.py b/api/core/index/base.py index 166b2d65c0..33178ff83b 100644 --- a/api/core/index/base.py +++ b/api/core/index/base.py @@ -1,9 +1,9 @@ from __future__ import annotations -from abc import abstractmethod, ABC -from typing import List, Any -from langchain.schema import Document, BaseRetriever +from abc import ABC, abstractmethod +from typing import Any, List +from langchain.schema import BaseRetriever, Document from models.dataset import Dataset diff --git a/api/core/index/index.py b/api/core/index/index.py index ce11171d0c..56ce3c99c6 100644 --- a/api/core/index/index.py +++ b/api/core/index/index.py @@ -1,11 +1,10 @@ -from flask import current_app -from langchain.embeddings import OpenAIEmbeddings - from core.embedding.cached_embedding import CacheEmbedding -from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig +from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex from core.index.vector_index.vector_index import VectorIndex from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType +from flask import current_app +from langchain.embeddings import OpenAIEmbeddings from models.dataset import Dataset diff --git a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py b/api/core/index/keyword_table_index/jieba_keyword_table_handler.py index db9fd027a0..fc07402206 100644 --- a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py +++ b/api/core/index/keyword_table_index/jieba_keyword_table_handler.py @@ -2,9 +2,8 @@ import re from typing import Set import jieba -from jieba.analyse import default_tfidf - from core.index.keyword_table_index.stopwords import STOPWORDS +from jieba.analyse import default_tfidf class JiebaKeywordTableHandler: diff --git a/api/core/index/keyword_table_index/keyword_table_index.py b/api/core/index/keyword_table_index/keyword_table_index.py index d7f569b19c..06eef1ebf2 100644 --- a/api/core/index/keyword_table_index/keyword_table_index.py +++ b/api/core/index/keyword_table_index/keyword_table_index.py @@ -1,14 +1,13 @@ import json from collections import defaultdict -from typing import Any, List, Optional, Dict - -from langchain.schema import Document, BaseRetriever -from pydantic import BaseModel, Field, Extra +from typing import Any, Dict, List, Optional from core.index.base import BaseIndex from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler from extensions.ext_database import db -from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable +from langchain.schema import BaseRetriever, Document +from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment +from pydantic import BaseModel, Extra, Field class KeywordTableConfig(BaseModel): diff --git a/api/core/index/vector_index/base.py b/api/core/index/vector_index/base.py index 1775ff3edf..ccc1833821 100644 --- a/api/core/index/vector_index/base.py +++ b/api/core/index/vector_index/base.py @@ -1,16 +1,16 @@ import json import logging from abc import abstractmethod -from typing import List, Any, cast - -from langchain.embeddings.base import Embeddings -from langchain.schema import Document, BaseRetriever -from langchain.vectorstores import VectorStore +from typing import Any, List, cast from core.index.base import BaseIndex from extensions.ext_database import db -from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding +from langchain.embeddings.base import Embeddings +from langchain.schema import BaseRetriever, Document +from langchain.vectorstores import VectorStore +from models.dataset import Dataset, DatasetCollectionBinding from models.dataset import Document as DatasetDocument +from models.dataset import DocumentSegment class BaseVectorIndex(BaseIndex): diff --git a/api/core/index/vector_index/milvus_vector_index.py b/api/core/index/vector_index/milvus_vector_index.py index c175e5babb..67ba5a7b32 100644 --- a/api/core/index/vector_index/milvus_vector_index.py +++ b/api/core/index/vector_index/milvus_vector_index.py @@ -1,14 +1,13 @@ -from typing import cast, Any, List - -from langchain.embeddings.base import Embeddings -from langchain.schema import Document -from langchain.vectorstores import VectorStore -from pydantic import BaseModel, root_validator +from typing import Any, List, cast from core.index.base import BaseIndex from core.index.vector_index.base import BaseVectorIndex from core.vector_store.milvus_vector_store import MilvusVectorStore +from langchain.embeddings.base import Embeddings +from langchain.schema import Document +from langchain.vectorstores import VectorStore from models.dataset import Dataset +from pydantic import BaseModel, root_validator class MilvusConfig(BaseModel): diff --git a/api/core/index/vector_index/qdrant_vector_index.py b/api/core/index/vector_index/qdrant_vector_index.py index e797134036..f755fe4101 100644 --- a/api/core/index/vector_index/qdrant_vector_index.py +++ b/api/core/index/vector_index/qdrant_vector_index.py @@ -1,18 +1,17 @@ import os -from typing import Optional, Any, List, cast +from typing import Any, List, Optional, cast import qdrant_client -from langchain.embeddings.base import Embeddings -from langchain.schema import Document, BaseRetriever -from langchain.vectorstores import VectorStore -from pydantic import BaseModel -from qdrant_client.http.models import HnswConfigDiff - from core.index.base import BaseIndex from core.index.vector_index.base import BaseVectorIndex from core.vector_store.qdrant_vector_store import QdrantVectorStore from extensions.ext_database import db +from langchain.embeddings.base import Embeddings +from langchain.schema import BaseRetriever, Document +from langchain.vectorstores import VectorStore from models.dataset import Dataset, DatasetCollectionBinding +from pydantic import BaseModel +from qdrant_client.http.models import HnswConfigDiff class QdrantConfig(BaseModel): diff --git a/api/core/index/vector_index/vector_index.py b/api/core/index/vector_index/vector_index.py index fe93fad110..0a69c4f734 100644 --- a/api/core/index/vector_index/vector_index.py +++ b/api/core/index/vector_index/vector_index.py @@ -1,10 +1,9 @@ import json -from flask import current_app -from langchain.embeddings.base import Embeddings - from core.index.vector_index.base import BaseVectorIndex from extensions.ext_database import db +from flask import current_app +from langchain.embeddings.base import Embeddings from models.dataset import Dataset, Document @@ -29,7 +28,7 @@ class VectorIndex: raise ValueError(f"Vector store must be specified.") if vector_type == "weaviate": - from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig + from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex return WeaviateVectorIndex( dataset=dataset, @@ -42,7 +41,7 @@ class VectorIndex: attributes=attributes ) elif vector_type == "qdrant": - from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig + from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex return QdrantVectorIndex( dataset=dataset, @@ -55,7 +54,7 @@ class VectorIndex: embeddings=embeddings ) elif vector_type == "milvus": - from core.index.vector_index.milvus_vector_index import MilvusVectorIndex, MilvusConfig + from core.index.vector_index.milvus_vector_index import MilvusConfig, MilvusVectorIndex return MilvusVectorIndex( dataset=dataset, diff --git a/api/core/index/vector_index/weaviate_vector_index.py b/api/core/index/vector_index/weaviate_vector_index.py index 0ba8a20bca..b4add6c11a 100644 --- a/api/core/index/vector_index/weaviate_vector_index.py +++ b/api/core/index/vector_index/weaviate_vector_index.py @@ -1,16 +1,15 @@ -from typing import Optional, cast, Any, List +from typing import Any, List, Optional, cast import requests import weaviate -from langchain.embeddings.base import Embeddings -from langchain.schema import Document, BaseRetriever -from langchain.vectorstores import VectorStore -from pydantic import BaseModel, root_validator - from core.index.base import BaseIndex from core.index.vector_index.base import BaseVectorIndex from core.vector_store.weaviate_vector_store import WeaviateVectorStore +from langchain.embeddings.base import Embeddings +from langchain.schema import BaseRetriever, Document +from langchain.vectorstores import VectorStore from models.dataset import Dataset +from pydantic import BaseModel, root_validator class WeaviateConfig(BaseModel): diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index a55729b9ff..51c42bf75b 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -5,34 +5,34 @@ import re import threading import time import uuid -from typing import Optional, List, cast, Type, Union, Literal, AbstractSet, Collection, Any - -from flask import current_app, Flask -from flask_login import current_user -from langchain.schema import Document -from langchain.text_splitter import TextSplitter, TS, TokenTextSplitter -from sqlalchemy.orm.exc import ObjectDeletedError +from typing import AbstractSet, Any, Collection, List, Literal, Optional, Type, Union, cast from core.data_loader.file_extractor import FileExtractor from core.data_loader.loader.notion import NotionLoader from core.docstore.dataset_docstore import DatasetDocumentStore +from core.errors.error import ProviderTokenNotInitError from core.generator.llm_generator import LLMGenerator from core.index.index import IndexBuilder from core.model_manager import ModelManager -from core.errors.error import ProviderTokenNotInitError from core.model_runtime.entities.model_entities import ModelType, PriceType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer -from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter, EnhanceRecursiveCharacterTextSplitter +from core.spiltter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from flask import Flask, current_app +from flask_login import current_user +from langchain.schema import Document +from langchain.text_splitter import TS, TextSplitter, TokenTextSplitter from libs import helper +from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument -from models.dataset import Dataset, DocumentSegment, DatasetProcessRule +from models.dataset import DocumentSegment from models.model import UploadFile from models.source import DataSourceBinding +from sqlalchemy.orm.exc import ObjectDeletedError class IndexingRunner: diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 285cb3eeb1..663daa0856 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,7 +1,7 @@ from core.file.message_file_parser import MessageFileParser from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import PromptMessage, TextPromptMessageContent, UserPromptMessage, \ - AssistantPromptMessage, PromptMessageRole +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole, + TextPromptMessageContent, UserPromptMessage) from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import model_provider_factory from extensions.ext_database import db diff --git a/api/core/model_manager.py b/api/core/model_manager.py index c732e40995..e75f624f2e 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,10 +1,10 @@ -from typing import Optional, Union, Generator, cast, List, IO +from typing import IO, Generator, List, Optional, Union, cast from core.entities.provider_configuration import ProviderModelBundle from core.errors.error import ProviderTokenNotInitError from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 144b779540..58150ef4da 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -1,8 +1,8 @@ from abc import ABC -from typing import Optional, List +from typing import List, Optional from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.model_providers.__base.ai_model import AIModel _TEXT_COLOR_MAPPING = { diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 4bd86e81dc..e6268a7b09 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -1,11 +1,11 @@ import json import logging import sys -from typing import Optional, List +from typing import List, Optional from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult -from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index a3109ba585..b39427dccd 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -2,7 +2,6 @@ from typing import Dict from core.model_runtime.entities.model_entities import DefaultParameterName - PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = { DefaultParameterName.TEMPERATURE: { 'label': { diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index b5bd9e267a..76d4ef310e 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -2,10 +2,9 @@ from decimal import Decimal from enum import Enum from typing import Optional -from pydantic import BaseModel - from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo +from pydantic import BaseModel class LLMMode(Enum): diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 5aa0e19ef0..1504d2e5a5 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -2,9 +2,8 @@ from decimal import Decimal from enum import Enum from typing import Any, Optional -from pydantic import BaseModel - from core.model_runtime.entities.common_entities import I18nObject +from pydantic import BaseModel class ModelType(Enum): diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index bf3fe0878f..bd55d60795 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -1,10 +1,9 @@ from enum import Enum from typing import Optional -from pydantic import BaseModel - from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ModelType, ProviderModel, AIModelEntity +from core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel +from pydantic import BaseModel class ConfigurateMethod(Enum): diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/core/model_runtime/entities/text_embedding_entities.py index 7be3def379..499c76eb7d 100644 --- a/api/core/model_runtime/entities/text_embedding_entities.py +++ b/api/core/model_runtime/entities/text_embedding_entities.py @@ -1,8 +1,7 @@ from decimal import Decimal -from pydantic import BaseModel - from core.model_runtime.entities.model_entities import ModelUsage +from pydantic import BaseModel class EmbeddingUsage(ModelUsage): diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index b739016559..87ffc5896d 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -6,14 +6,13 @@ from abc import ABC, abstractmethod from typing import Optional import yaml -from pydantic import ValidationError - -from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE -from core.model_runtime.entities.model_entities import PriceInfo, AIModelEntity, PriceType, PriceConfig, \ - DefaultParameterName, FetchFrom, ModelType from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError +from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE +from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, ModelType, + PriceConfig, PriceInfo, PriceType) +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from pydantic import ValidationError class AIModel(ABC): diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 98f499c086..0bf6a385ac 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -2,15 +2,14 @@ import logging import os import time from abc import abstractmethod -from typing import Optional, Generator, Union, List +from typing import Generator, List, Optional, Union from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.logging_callback import LoggingCallback -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage -from core.model_runtime.entities.model_entities import ModelPropertyKey, PriceType, ParameterType, ParameterRule, \ - ModelType -from core.model_runtime.entities.llm_entities import LLMResult, LLMMode, LLMUsage, \ - LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool +from core.model_runtime.entities.model_entities import (ModelPropertyKey, ModelType, ParameterRule, ParameterType, + PriceType) from core.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 6c7aba2488..a856d42588 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -1,11 +1,10 @@ import importlib import os from abc import ABC, abstractmethod -from typing import Optional, Dict +from typing import Dict, Optional import yaml - -from core.model_runtime.entities.model_entities import ModelType, AIModelEntity +from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderEntity from core.model_runtime.model_providers.__base.ai_model import AIModel diff --git a/api/core/model_runtime/model_providers/__base/speech2text_model.py b/api/core/model_runtime/model_providers/__base/speech2text_model.py index 151bccc074..a084baf340 100644 --- a/api/core/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/core/model_runtime/model_providers/__base/speech2text_model.py @@ -1,6 +1,6 @@ import os from abc import abstractmethod -from typing import Optional, IO +from typing import IO, Optional from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.ai_model import AIModel diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index 93510b7825..6059b3f561 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -1,7 +1,8 @@ -from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer -from os.path import join, abspath, dirname -from typing import Any +from os.path import abspath, dirname, join from threading import Lock +from typing import Any + +from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer _tokenizer = None _lock = Lock() diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 425f3e5398..987f2fabf1 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -1,18 +1,16 @@ -from typing import Optional, Generator, Union, List +from typing import Generator, List, Optional, Union import anthropic from anthropic import Anthropic, Stream -from anthropic.types import completion_create_params, Completion -from httpx import Timeout - -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, AssistantPromptMessage, \ - SystemPromptMessage -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \ - LLMResultChunkDelta -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError, InvokeError +from anthropic.types import Completion, completion_create_params +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, + SystemPromptMessage, UserPromptMessage) +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from httpx import Timeout class AnthropicLargeLanguageModel(LargeLanguageModel): diff --git a/api/core/model_runtime/model_providers/azure_openai/_common.py b/api/core/model_runtime/model_providers/azure_openai/_common.py index db12dd6d83..627b487357 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_common.py +++ b/api/core/model_runtime/model_providers/azure_openai/_common.py @@ -1,10 +1,8 @@ import openai -from httpx import Timeout - +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPENAI_API_VERSION - -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError, InvokeError +from httpx import Timeout class _CommonAzureOpenAI: diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 0ccac24a1e..75c7ec508b 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -1,10 +1,9 @@ -from pydantic import BaseModel - -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \ - DefaultParameterName, PriceConfig, ModelPropertyKey -from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, I18nObject, + ModelFeature, ModelPropertyKey, ModelType, ParameterRule, + PriceConfig) +from pydantic import BaseModel AZURE_OPENAI_API_VERSION = '2023-12-01-preview' diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 72965a7613..55f0a9408f 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -1,23 +1,22 @@ import logging -from typing import Optional, Generator, Union, List, cast +from typing import Generator, List, Optional, Union, cast import tiktoken -from openai import AzureOpenAI, Stream -from openai.types import Completion -from openai.types.chat import ChatCompletionChunk, ChatCompletion, ChatCompletionMessageToolCall -from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaFunctionCall -from openai.types.chat.chat_completion_message import FunctionCall - -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, \ - LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, AssistantPromptMessage, \ - UserPromptMessage, PromptMessageContentType, ImagePromptMessageContent, \ - TextPromptMessageContent, SystemPromptMessage, ToolPromptMessage +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, + PromptMessage, PromptMessageContentType, PromptMessageTool, + SystemPromptMessage, TextPromptMessageContent, + ToolPromptMessage, UserPromptMessage) from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel +from openai import AzureOpenAI, Stream +from openai.types import Completion +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall +from openai.types.chat.chat_completion_message import FunctionCall logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index d71fe01ea2..06897a6c45 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -4,14 +4,13 @@ from typing import Optional, Tuple import numpy as np import tiktoken -from openai import AzureOpenAI - -from core.model_runtime.entities.model_entities import PriceType, AIModelEntity -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage +from core.model_runtime.entities.model_entities import AIModelEntity, PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI from core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_BASE_MODELS, AzureBaseModel +from openai import AzureOpenAI class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/baichuan/baichuan.py b/api/core/model_runtime/model_providers/baichuan/baichuan.py index 815b44c358..71bd6b5d92 100644 --- a/api/core/model_runtime/model_providers/baichuan/baichuan.py +++ b/api/core/model_runtime/model_providers/baichuan/baichuan.py @@ -1,7 +1,8 @@ -from core.model_runtime.model_providers.__base.model_provider import ModelProvider +import logging + from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError -import logging +from core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py index ae1db90dc6..4562bb2be7 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py @@ -1,5 +1,6 @@ import re + class BaichuanTokenizer(object): @classmethod def count_chinese_characters(cls, text: str) -> int: diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index bbdca0ec2d..73081e67d4 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -1,12 +1,18 @@ -from os.path import join -from typing import List, Optional, Generator, Union, Dict, Any -from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import \ - InsufficientAccountBalance, InvalidAPIKeyError, InternalServerError, RateLimitReachedError, InvalidAuthenticationError, BadRequestError from enum import Enum -from json import dumps, loads -from requests import post -from time import time from hashlib import md5 +from json import dumps, loads +from os.path import join +from time import time +from typing import Any, Dict, Generator, List, Optional, Union + +from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (BadRequestError, + InsufficientAccountBalance, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError) +from requests import post + class BaichuanMessage: class Role(Enum): diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py index 8b646cc765..c8bb1feb52 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py @@ -1,14 +1,21 @@ from typing import Generator, List, Optional, Union, cast -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError, InvokeError + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, + SystemPromptMessage, UserPromptMessage) +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import \ - InsufficientAccountBalance, InvalidAPIKeyError, InternalServerError, RateLimitReachedError, InvalidAuthenticationError, BadRequestError -from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel, BaichuanMessage +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer +from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel +from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (BadRequestError, + InsufficientAccountBalance, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError) + class BaichuanLarguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index b547f59d95..20aafea1eb 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -1,19 +1,22 @@ +import time +from json import dumps, loads from typing import Optional, Tuple from core.model_runtime.entities.model_entities import PriceType -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.errors.invoke import InvokeError, InvokeConnectionError, InvokeServerUnavailableError, \ - InvokeRateLimitError, InvokeAuthorizationError, InvokeBadRequestError from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer -from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import InvalidAPIKeyError, InsufficientAccountBalance, \ - InvalidAuthenticationError, RateLimitReachedError, InternalServerError, BadRequestError - +from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (BadRequestError, + InsufficientAccountBalance, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError) from requests import post -from json import dumps, loads -import time class BaichuanTextEmbeddingModel(TextEmbeddingModel): """ diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py index 6884ede2bc..44868fcf73 100644 --- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py +++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py @@ -1,30 +1,22 @@ -from typing import Generator, List, Optional -from requests import post - - -from os.path import join -from typing import cast +import logging from json import dumps +from os.path import join +from typing import Generator, List, Optional, cast -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, AssistantPromptMessage, \ - SystemPromptMessage -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \ - LLMResultChunkDelta -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError, InvokeError +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction, + PromptMessageTool, SystemPromptMessage, UserPromptMessage) +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, AssistantPromptMessage, \ - PromptMessageFunction, UserPromptMessage, SystemPromptMessage from core.model_runtime.utils import helper -from openai import OpenAI, Stream, \ - APIConnectionError, APITimeoutError, AuthenticationError, InternalServerError, \ - RateLimitError, ConflictError, NotFoundError, UnprocessableEntityError, PermissionDeniedError -from openai.types.chat import ChatCompletionChunk, ChatCompletion -from openai.types.chat.chat_completion_message import FunctionCall from httpx import Timeout - -import logging +from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError, + NotFoundError, OpenAI, PermissionDeniedError, RateLimitError, Stream, UnprocessableEntityError) +from openai.types.chat import ChatCompletion, ChatCompletionChunk +from openai.types.chat.chat_completion_message import FunctionCall +from requests import post logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index 24ac98d476..8c82cce766 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -1,10 +1,9 @@ from typing import Optional import cohere - -from core.model_runtime.entities.rerank_entities import RerankResult, RerankDocument -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.rerank_model import RerankModel diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 86e87e8c1e..6fd5c9144c 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -1,24 +1,21 @@ -from typing import Optional, Generator, Union, List +import logging +from typing import Generator, List, Optional, Union -import google.generativeai as genai import google.api_core.exceptions as exceptions +import google.generativeai as genai import google.generativeai.client as client -from google.generativeai.types import HarmCategory, HarmBlockThreshold - -from google.generativeai.types import GenerateContentResponse, ContentType -from google.generativeai.types.content_types import to_part - -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, AssistantPromptMessage, \ - SystemPromptMessage, PromptMessageRole, PromptMessageContentType -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \ - LLMResultChunkDelta -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError, InvokeError +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, + PromptMessageContentType, PromptMessageRole, + PromptMessageTool, SystemPromptMessage, UserPromptMessage) +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers import google from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory +from google.generativeai.types.content_types import to_part -import logging logger = logging.getLogger(__name__) class GoogleLargeLanguageModel(LargeLanguageModel): diff --git a/api/core/model_runtime/model_providers/huggingface_hub/_common.py b/api/core/model_runtime/model_providers/huggingface_hub/_common.py index 4a2c35fec1..1140c947b9 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/_common.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/_common.py @@ -1,6 +1,5 @@ -from huggingface_hub.utils import HfHubHTTPError, BadRequestError - from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError +from huggingface_hub.utils import BadRequestError, HfHubHTTPError class _CommonHuggingfaceHub: diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index 8f4b9d903c..f3d5a853d7 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -1,19 +1,18 @@ -from typing import Optional, List, Union, Generator - -from huggingface_hub import InferenceClient -from huggingface_hub.hf_api import HfApi -from huggingface_hub.utils import BadRequestError +from typing import Generator, List, Optional, Union from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMMode -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, \ - UserPromptMessage, SystemPromptMessage -from core.model_runtime.entities.model_entities import ParameterRule, DefaultParameterName, AIModelEntity, ModelType, \ - FetchFrom, ModelPropertyKey +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, + SystemPromptMessage, UserPromptMessage) +from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, + ModelPropertyKey, ModelType, ParameterRule) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub +from huggingface_hub import InferenceClient +from huggingface_hub.hf_api import HfApi +from huggingface_hub.utils import BadRequestError class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel): diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index afec45ff0a..f0dc632fae 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -4,15 +4,13 @@ from typing import Optional import numpy as np import requests -from huggingface_hub import InferenceClient, HfApi - from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub - +from huggingface_hub import HfApi, InferenceClient HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/' diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py index dd0b0b7b25..c388341d51 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py @@ -1,5 +1,7 @@ +from os.path import abspath, dirname, join + from transformers import AutoTokenizer -from os.path import join, abspath, dirname + class JinaTokenizer: @staticmethod diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index 7e549a512f..7cd1c3e593 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -1,17 +1,16 @@ +import time +from json import JSONDecodeError, dumps from typing import Optional from core.model_runtime.entities.model_entities import PriceType -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.errors.invoke import InvokeError, InvokeConnectionError, InvokeServerUnavailableError, \ - InvokeRateLimitError, InvokeAuthorizationError, InvokeBadRequestError from core.model_runtime.model_providers.jina.text_embedding.jina_tokenizer import JinaTokenizer - from requests import post -from json import dumps, JSONDecodeError -import time class JinaTextEmbeddingModel(TextEmbeddingModel): """ diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py index e9348fa114..117ef8c399 100644 --- a/api/core/model_runtime/model_providers/localai/llm/llm.py +++ b/api/core/model_runtime/model_providers/localai/llm/llm.py @@ -1,22 +1,24 @@ -from typing import Generator, List, Optional, Union, cast -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage -from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError, InvokeError -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from openai import OpenAI, Stream, \ - APIConnectionError, APITimeoutError, AuthenticationError, InternalServerError, \ - RateLimitError, ConflictError, NotFoundError, UnprocessableEntityError, PermissionDeniedError -from openai.types.chat import ChatCompletionChunk, ChatCompletion -from openai.types.completion import Completion -from openai.types.chat.chat_completion_message import FunctionCall -from httpx import Timeout from os.path import join +from typing import Generator, List, Optional, Union, cast +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, + SystemPromptMessage, UserPromptMessage) +from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, + ParameterRule, ParameterType) +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils import helper +from httpx import Timeout +from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError, + NotFoundError, OpenAI, PermissionDeniedError, RateLimitError, Stream, UnprocessableEntityError) +from openai.types.chat import ChatCompletion, ChatCompletionChunk +from openai.types.chat.chat_completion_message import FunctionCall +from openai.types.completion import Completion + class LocalAILarguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, diff --git a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index 82f0715062..511f09e3e7 100644 --- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -1,17 +1,16 @@ +import time +from json import JSONDecodeError, dumps +from os.path import join from typing import Optional from core.model_runtime.entities.model_entities import PriceType -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.errors.invoke import InvokeError, InvokeConnectionError, InvokeServerUnavailableError, \ - InvokeRateLimitError, InvokeAuthorizationError, InvokeBadRequestError - from requests import post -from json import dumps, JSONDecodeError -from os.path import join -import time class LocalAITextEmbeddingModel(TextEmbeddingModel): """ diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 1c00484977..c35100ec07 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -1,12 +1,14 @@ -from core.model_runtime.model_providers.minimax.llm.errors import BadRequestError, InvalidAPIKeyError, \ - InternalServerError, RateLimitReachedError, InvalidAuthenticationError, InsufficientAccountBalanceError -from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage -from typing import List, Dict, Any, Generator, Union - -from json import dumps, loads -from requests import post, Response -from time import time from hashlib import md5 +from json import dumps, loads +from time import time +from typing import Any, Dict, Generator, List, Union + +from core.model_runtime.model_providers.minimax.llm.errors import (BadRequestError, InsufficientAccountBalanceError, + InternalServerError, InvalidAPIKeyError, + InvalidAuthenticationError, RateLimitReachedError) +from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage +from requests import Response, post + class MinimaxChatCompletion(object): """ diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 08d4ab585c..a680ef0a2b 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -1,12 +1,14 @@ -from core.model_runtime.model_providers.minimax.llm.errors import BadRequestError, InvalidAPIKeyError, \ - InternalServerError, RateLimitReachedError, InvalidAuthenticationError, InsufficientAccountBalanceError -from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage -from typing import List, Dict, Any, Generator, Union - -from json import dumps, loads -from requests import post, Response -from time import time from hashlib import md5 +from json import dumps, loads +from time import time +from typing import Any, Dict, Generator, List, Union + +from core.model_runtime.model_providers.minimax.llm.errors import (BadRequestError, InsufficientAccountBalanceError, + InternalServerError, InvalidAPIKeyError, + InvalidAuthenticationError, RateLimitReachedError) +from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage +from requests import Response, post + class MinimaxChatCompletionPro(object): """ diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py index 96557418a2..8937b1c128 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/llm.py +++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py @@ -1,17 +1,20 @@ from typing import Generator, List, Optional, Union -from core.model_runtime.model_providers.minimax.llm.errors import BadRequestError, InvalidAPIKeyError, \ - InternalServerError, RateLimitReachedError, InvalidAuthenticationError, InsufficientAccountBalanceError -from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage + +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, + SystemPromptMessage, UserPromptMessage) +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, ParameterRule, ParameterType +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.minimax.llm.chat_completion import MinimaxChatCompletion from core.model_runtime.model_providers.minimax.llm.chat_completion_pro import MinimaxChatCompletionPro +from core.model_runtime.model_providers.minimax.llm.errors import (BadRequestError, InsufficientAccountBalanceError, + InternalServerError, InvalidAPIKeyError, + InvalidAuthenticationError, RateLimitReachedError) +from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage -from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError, InvokeError -from core.model_runtime.errors.validate import CredentialsValidateFailedError class MinimaxLargeLanguageModel(LargeLanguageModel): model_apis = { diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py index 3555b4d7ae..6d1e8e64d8 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/types.py +++ b/api/core/model_runtime/model_providers/minimax/llm/types.py @@ -1,5 +1,6 @@ -from typing import Dict, Any from enum import Enum +from typing import Any, Dict + class MinimaxMessage: class Role(Enum): diff --git a/api/core/model_runtime/model_providers/minimax/minimax.py b/api/core/model_runtime/model_providers/minimax/minimax.py index 2dd844c00e..97afe6aa44 100644 --- a/api/core/model_runtime/model_providers/minimax/minimax.py +++ b/api/core/model_runtime/model_providers/minimax/minimax.py @@ -1,7 +1,8 @@ -from core.model_runtime.model_providers.__base.model_provider import ModelProvider +import logging + from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError -import logging +from core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index cf8b403d0e..65f2a9a225 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -1,18 +1,18 @@ +import time +from json import dumps, loads from typing import Optional from core.model_runtime.entities.model_entities import PriceType -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.errors.invoke import InvokeError, InvokeConnectionError, InvokeServerUnavailableError, \ - InvokeRateLimitError, InvokeAuthorizationError, InvokeBadRequestError -from core.model_runtime.model_providers.minimax.llm.errors import InvalidAPIKeyError, InsufficientAccountBalanceError, \ - InvalidAuthenticationError, RateLimitReachedError, InternalServerError, BadRequestError - +from core.model_runtime.model_providers.minimax.llm.errors import (BadRequestError, InsufficientAccountBalanceError, + InternalServerError, InvalidAPIKeyError, + InvalidAuthenticationError, RateLimitReachedError) from requests import post -from json import dumps, loads -import time class MinimaxTextEmbeddingModel(TextEmbeddingModel): """ diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 1435ab89b1..375017c563 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -5,13 +5,12 @@ from collections import OrderedDict from typing import Optional import yaml -from pydantic import BaseModel - from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import SimpleProviderEntity, ProviderConfig, ProviderEntity +from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator +from pydantic import BaseModel logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py index 12cf4d81b8..91705c3ba8 100644 --- a/api/core/model_runtime/model_providers/openai/_common.py +++ b/api/core/model_runtime/model_providers/openai/_common.py @@ -1,9 +1,8 @@ import openai +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from httpx import Timeout -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError, InvokeError - class _CommonOpenAI: def _to_credential_kwargs(self, credentials: dict) -> dict: diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 17e5910b92..97397c2274 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -1,24 +1,23 @@ import logging -from typing import Optional, Generator, Union, List, cast +from typing import Generator, List, Optional, Union, cast import tiktoken -from openai import OpenAI, Stream -from openai.types import Completion -from openai.types.chat import ChatCompletionChunk, ChatCompletion, ChatCompletionMessageToolCall -from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaFunctionCall -from openai.types.chat.chat_completion_message import FunctionCall - -from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, AssistantPromptMessage, \ - PromptMessageFunction, UserPromptMessage, PromptMessageContentType, ImagePromptMessageContent, \ - TextPromptMessageContent, SystemPromptMessage, ToolPromptMessage -from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject, ModelType, FetchFrom, \ - PriceConfig, AIModelEntity, FetchFrom -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, \ - LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, + PromptMessage, PromptMessageContentType, + PromptMessageFunction, PromptMessageTool, SystemPromptMessage, + TextPromptMessageContent, ToolPromptMessage, + UserPromptMessage) +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType, PriceConfig from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI from core.model_runtime.utils import helper +from openai import OpenAI, Stream +from openai.types import Completion +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall +from openai.types.chat.chat_completion_message import FunctionCall logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/openai/moderation/moderation.py b/api/core/model_runtime/model_providers/openai/moderation/moderation.py index b1d0e57ad2..2a0901d752 100644 --- a/api/core/model_runtime/model_providers/openai/moderation/moderation.py +++ b/api/core/model_runtime/model_providers/openai/moderation/moderation.py @@ -1,12 +1,11 @@ from typing import Optional -from openai import OpenAI -from openai.types import ModerationCreateResponse - from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.moderation_model import ModerationModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI +from openai import OpenAI +from openai.types import ModerationCreateResponse class OpenAIModerationModel(_CommonOpenAI, ModerationModel): diff --git a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py index dfb1396c11..b2b337a563 100644 --- a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py @@ -1,10 +1,9 @@ -from typing import Optional, IO - -from openai import OpenAI +from typing import IO, Optional from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI +from openai import OpenAI class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel): diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index 6c4c4a7eb9..cde354e861 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -4,13 +4,12 @@ from typing import Optional, Tuple import numpy as np import tiktoken -from openai import OpenAI - from core.model_runtime.entities.model_entities import PriceType -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI +from openai import OpenAI class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py index bf00caabd0..9b7b052b99 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py @@ -1,13 +1,13 @@ from decimal import Decimal -import requests +import requests from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.model_entities import AIModelEntity, DefaultParameterName, \ - FetchFrom, ModelPropertyKey, ModelType, ParameterRule, ParameterType, PriceConfig - -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, \ - InvokeRateLimitError, InvokeAuthorizationError, InvokeBadRequestError, InvokeError +from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, + ModelPropertyKey, ModelType, ParameterRule, ParameterType, + PriceConfig) +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) class _CommonOAI_API_Compat: diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index acb974b050..e73f47b8e8 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -1,27 +1,24 @@ +import json import logging from decimal import Decimal +from typing import Generator, List, Optional, Union, cast from urllib.parse import urljoin import requests -import json - -from typing import Optional, Generator, Union, List, cast - from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.utils import helper - -from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, \ - AssistantPromptMessage, PromptMessageContent, \ - PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, \ - ToolPromptMessage -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, \ - DefaultParameterName, \ - ParameterType, ModelPropertyKey, FetchFrom, AIModelEntity from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, + PromptMessage, PromptMessageContent, PromptMessageContentType, + PromptMessageFunction, PromptMessageTool, SystemPromptMessage, + ToolPromptMessage, UserPromptMessage) +from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, + ModelPropertyKey, ModelType, ParameterRule, ParameterType, + PriceConfig) from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.utils import helper logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 19ec73d109..b735fdb792 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -1,16 +1,15 @@ +import json import time from decimal import Decimal from typing import Optional from urllib.parse import urljoin -import requests -import json import numpy as np - +import requests from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import PriceType, ModelPropertyKey, ModelType, AIModelEntity, FetchFrom, \ - PriceConfig -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage +from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, + PriceConfig, PriceType) +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat diff --git a/api/core/model_runtime/model_providers/openllm/llm/llm.py b/api/core/model_runtime/model_providers/openllm/llm/llm.py index 609ea19b59..af62ddf92f 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/llm.py +++ b/api/core/model_runtime/model_providers/openllm/llm/llm.py @@ -1,16 +1,23 @@ from typing import Generator, List, Optional, Union -from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors import BadRequestError, InvalidAPIKeyError, \ - InternalServerError, RateLimitReachedError, InvalidAuthenticationError, InsufficientAccountBalanceError -from core.model_runtime.model_providers.openllm.llm.openllm_generate import OpenLLMGenerate, OpenLLMGenerateMessage -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage -from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError, InvokeError +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, + SystemPromptMessage, UserPromptMessage) +from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, + ParameterRule, ParameterType) +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.openllm.llm.openllm_generate import OpenLLMGenerate, OpenLLMGenerateMessage +from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors import (BadRequestError, + InsufficientAccountBalanceError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError) + class OpenLLMLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 7d9bc349be..f14f4cd646 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -1,11 +1,16 @@ -from typing import Any, Dict, List, Union, Generator -from requests import post, Response -from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema -from json import dumps, loads from enum import Enum +from json import dumps, loads +from typing import Any, Dict, Generator, List, Union + +from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors import (BadRequestError, + InsufficientAccountBalanceError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError) +from requests import Response, post +from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema -from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors import BadRequestError, InvalidAPIKeyError, \ - InternalServerError, RateLimitReachedError, InvalidAuthenticationError, InsufficientAccountBalanceError class OpenLLMGenerateMessage: class Role(Enum): diff --git a/api/core/model_runtime/model_providers/openllm/openllm.py b/api/core/model_runtime/model_providers/openllm/openllm.py index b1d75f7de2..8014802144 100644 --- a/api/core/model_runtime/model_providers/openllm/openllm.py +++ b/api/core/model_runtime/model_providers/openllm/openllm.py @@ -1,6 +1,7 @@ -from core.model_runtime.model_providers.__base.model_provider import ModelProvider import logging +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py index 65a6834ca8..2f30427d36 100644 --- a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py @@ -1,17 +1,16 @@ +import time +from json import dumps, loads from typing import Optional from core.model_runtime.entities.model_entities import PriceType -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.errors.invoke import InvokeError, InvokeConnectionError, InvokeServerUnavailableError, \ - InvokeRateLimitError, InvokeAuthorizationError, InvokeBadRequestError - from requests import post -from requests.exceptions import InvalidSchema, MissingSchema, ConnectionError -from json import dumps, loads +from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema -import time class OpenLLMTextEmbeddingModel(TextEmbeddingModel): """ diff --git a/api/core/model_runtime/model_providers/replicate/_common.py b/api/core/model_runtime/model_providers/replicate/_common.py index 1a14d29cea..ad130cabbc 100644 --- a/api/core/model_runtime/model_providers/replicate/_common.py +++ b/api/core/model_runtime/model_providers/replicate/_common.py @@ -1,6 +1,5 @@ -from replicate.exceptions import ReplicateError, ModelError - from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError +from replicate.exceptions import ModelError, ReplicateError class _CommonReplicate: diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index 54134feca9..69c0a82636 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -1,17 +1,17 @@ -from typing import Optional, List, Union, Generator - -from replicate import Client as ReplicateClient -from replicate.exceptions import ReplicateError -from replicate.prediction import Prediction +from typing import Generator, List, Optional, Union from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMResult, LLMMode, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, \ - PromptMessageRole, UserPromptMessage, SystemPromptMessage -from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType, ModelPropertyKey +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole, + PromptMessageTool, SystemPromptMessage, UserPromptMessage) +from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, + ParameterRule) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.replicate._common import _CommonReplicate +from replicate import Client as ReplicateClient +from replicate.exceptions import ReplicateError +from replicate.prediction import Prediction class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 3d6fdc74a7..37a275614c 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -2,14 +2,13 @@ import json import time from typing import Optional -from replicate import Client as ReplicateClient - from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.replicate._common import _CommonReplicate +from replicate import Client as ReplicateClient class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/spark/llm/_client.py b/api/core/model_runtime/model_providers/spark/llm/_client.py index d3e32b0259..37ae40c1d8 100644 --- a/api/core/model_runtime/model_providers/spark/llm/_client.py +++ b/api/core/model_runtime/model_providers/spark/llm/_client.py @@ -4,12 +4,11 @@ import hashlib import hmac import json import queue -from typing import Optional -from urllib.parse import urlparse import ssl from datetime import datetime from time import mktime -from urllib.parse import urlencode +from typing import Optional +from urllib.parse import urlencode, urlparse from wsgiref.handlers import format_date_time import websocket diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py index 56f13cdbc3..33475f5769 100644 --- a/api/core/model_runtime/model_providers/spark/llm/llm.py +++ b/api/core/model_runtime/model_providers/spark/llm/llm.py @@ -1,12 +1,11 @@ import threading -from typing import Optional, Generator, Union, List +from typing import Generator, List, Optional, Union -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, AssistantPromptMessage, \ - SystemPromptMessage -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \ - LLMResultChunkDelta -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError, InvokeError +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, + SystemPromptMessage, UserPromptMessage) +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py index f2c74b808b..89198fe4b0 100644 --- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -1,9 +1,11 @@ from typing import Generator, List, Optional, Union + from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): def _update_endpoint_url(self, credentials: dict): diff --git a/api/core/model_runtime/model_providers/tongyi/llm/_client.py b/api/core/model_runtime/model_providers/tongyi/llm/_client.py index c8241fe084..2aab69af7a 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/_client.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/_client.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, List, Optional +from typing import Any, Dict, List, Optional from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms import Tongyi diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index b33426353c..5cc05db0fb 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -1,23 +1,22 @@ from http import HTTPStatus -from typing import Optional, Generator, Union, List +from typing import Generator, List, Optional, Union + import dashscope -from dashscope.api_entities.dashscope_response import DashScopeAPIResponse -from dashscope.common.error import AuthenticationError, RequestFailure, \ - InvalidParameter, UnsupportedModel, ServiceUnavailableError, UnsupportedHTTPMethod - -from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry - -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, AssistantPromptMessage, \ - SystemPromptMessage -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \ - LLMResultChunkDelta -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError, InvokeError +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, + SystemPromptMessage, UserPromptMessage) +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dashscope.api_entities.dashscope_response import DashScopeAPIResponse +from dashscope.common.error import (AuthenticationError, InvalidParameter, RequestFailure, ServiceUnavailableError, + UnsupportedHTTPMethod, UnsupportedModel) +from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry from ._client import EnhanceTongyi + class TongyiLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index c70a4daedf..65081a9665 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -1,12 +1,15 @@ +from datetime import datetime, timedelta from enum import Enum from json import dumps, loads -from requests import post, Response -from typing import Any, Dict, Union, Generator, List -from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import BadRequestError, InvalidAPIKeyError, \ - InternalServerError, RateLimitReachedError, InvalidAuthenticationError -from core.model_runtime.entities.message_entities import PromptMessageTool -from datetime import datetime, timedelta from threading import Lock +from typing import Any, Dict, Generator, List, Union + +from core.model_runtime.entities.message_entities import PromptMessageTool +from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (BadRequestError, InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError) +from requests import Response, post # map api_key to access_token baidu_access_tokens: Dict[str, 'BaiduAccessToken'] = {} diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index 0c35bf7ca4..27b2bce9af 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -1,13 +1,18 @@ from typing import Generator, List, Optional, Union, cast -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError, InvokeError + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, + SystemPromptMessage, UserPromptMessage) +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage, BaiduAccessToken -from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import \ - InsufficientAccountBalance, InvalidAPIKeyError, InternalServerError, RateLimitReachedError, InvalidAuthenticationError, BadRequestError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage +from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (BadRequestError, InsufficientAccountBalance, + InternalServerError, InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError) + class ErnieBotLarguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin.py b/api/core/model_runtime/model_providers/wenxin/wenxin.py index ceddd73707..04845d06bc 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin.py +++ b/api/core/model_runtime/model_providers/wenxin/wenxin.py @@ -1,7 +1,8 @@ -from core.model_runtime.model_providers.__base.model_provider import ModelProvider +import logging + from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError -import logging +from core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index c32d9a3d8e..8f068f564d 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -1,29 +1,27 @@ -from typing import Generator, List, Optional, Union, Iterator, cast -from openai import OpenAI -from openai.types.chat import ChatCompletionChunk, ChatCompletion -from openai.types.completion import Completion -from openai.types.chat.chat_completion_message import FunctionCall -from openai.types.chat import ChatCompletionChunk, ChatCompletion, ChatCompletionMessageToolCall -from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaFunctionCall -from openai import OpenAI, Stream, \ - APIConnectionError, APITimeoutError, AuthenticationError, InternalServerError, \ - RateLimitError, ConflictError, NotFoundError, UnprocessableEntityError, PermissionDeniedError +from typing import Generator, Iterator, List, Optional, Union, cast -from xinference_client.client.restful.restful_client import \ - RESTfulChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatglmCppChatModelHandle, Client - -from core.model_runtime.entities.model_entities import AIModelEntity - -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, SystemPromptMessage, AssistantPromptMessage from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType, ModelPropertyKey +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, + SystemPromptMessage, UserPromptMessage) +from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, + ParameterRule, ParameterType) +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.xinference.llm.xinference_helper import XinferenceHelper, XinferenceModelExtraParameter -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError, InvokeError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.xinference.llm.xinference_helper import (XinferenceHelper, + XinferenceModelExtraParameter) from core.model_runtime.utils import helper +from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError, + NotFoundError, OpenAI, PermissionDeniedError, RateLimitError, Stream, UnprocessableEntityError) +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall +from openai.types.chat.chat_completion_message import FunctionCall +from openai.types.completion import Completion +from xinference_client.client.restful.restful_client import (Client, RESTfulChatglmCppChatModelHandle, + RESTfulChatModelHandle, RESTfulGenerateModelHandle) + class XinferenceAILargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], diff --git a/api/core/model_runtime/model_providers/xinference/llm/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/llm/xinference_helper.py index d73fab29e5..88b5a558ac 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/llm/xinference_helper.py @@ -1,11 +1,13 @@ -from requests import get -from requests.sessions import Session -from requests.adapters import HTTPAdapter -from requests.exceptions import MissingSchema, ConnectionError, Timeout -from time import time from threading import Lock +from time import time from typing import List +from requests import get +from requests.adapters import HTTPAdapter +from requests.exceptions import ConnectionError, MissingSchema, Timeout +from requests.sessions import Session + + class XinferenceModelExtraParameter(object): model_format: str model_handle_type: str diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index cb3865abd7..9ec9e09aa0 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -1,14 +1,14 @@ from typing import Optional -from core.model_runtime.entities.rerank_entities import RerankResult, RerankDocument -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.rerank_model import RerankModel -from core.model_runtime.entities.model_entities import FetchFrom, ModelType, AIModelEntity -from core.model_runtime.entities.common_entities import I18nObject +from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle -from xinference_client.client.restful.restful_client import RESTfulRerankModelHandle, Client class XinferenceRerankModel(RerankModel): """ diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index d8aba98098..e7d7959417 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -1,16 +1,15 @@ +import time from typing import Optional -from core.model_runtime.entities.model_entities import PriceType, FetchFrom, ModelType, AIModelEntity from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.errors.invoke import InvokeError, InvokeConnectionError, InvokeServerUnavailableError, \ - InvokeRateLimitError, InvokeAuthorizationError, InvokeBadRequestError +from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle -from xinference_client.client.restful.restful_client import RESTfulEmbeddingModelHandle, RESTfulModelHandle, Client - -import time class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ diff --git a/api/core/model_runtime/model_providers/zhipuai/_client.py b/api/core/model_runtime/model_providers/zhipuai/_client.py index 0072366c98..31042d318d 100644 --- a/api/core/model_runtime/model_providers/zhipuai/_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/_client.py @@ -4,8 +4,7 @@ from __future__ import annotations import logging import posixpath -from pydantic import Extra, BaseModel - +from pydantic import BaseModel, Extra from zhipuai.model_api.api import InvokeType from zhipuai.utils import jwt_token from zhipuai.utils.http_client import post, stream diff --git a/api/core/model_runtime/model_providers/zhipuai/_common.py b/api/core/model_runtime/model_providers/zhipuai/_common.py index 19bf82b0c7..d1479ae92f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/_common.py +++ b/api/core/model_runtime/model_providers/zhipuai/_common.py @@ -1,5 +1,5 @@ -from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ - InvokeAuthorizationError, InvokeBadRequestError, InvokeError +from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, + InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) class _CommonZhipuaiAI: diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index 6624a41e83..e8e4019ea5 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -1,18 +1,9 @@ import json -from typing import ( - Any, - Dict, - List, - Optional, - Generator, - Union -) +from typing import Any, Dict, Generator, List, Optional, Union -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, \ - AssistantPromptMessage, \ - SystemPromptMessage, PromptMessageRole -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \ - LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole, + PromptMessageTool, SystemPromptMessage, UserPromptMessage) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 308723c457..0fd04134b3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -1,14 +1,13 @@ import time -from typing import Optional, List, Tuple - -from langchain.schema.language_model import _get_token_ids_default_method +from typing import List, Optional, Tuple from core.model_runtime.entities.model_entities import PriceType -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI +from langchain.schema.language_model import _get_token_ids_default_method class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index 54031a6066..3e6f3526ef 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -1,6 +1,6 @@ from typing import Optional -from core.model_runtime.entities.provider_entities import FormType, CredentialFormSchema +from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType class CommonValidator: diff --git a/api/core/model_runtime/utils/_compat.py b/api/core/model_runtime/utils/_compat.py index 1cb8b0d2cb..305edcac8f 100644 --- a/api/core/model_runtime/utils/_compat.py +++ b/api/core/model_runtime/utils/_compat.py @@ -1,6 +1,4 @@ -from typing import ( - Any -) +from typing import Any from pydantic import BaseModel from pydantic.version import VERSION as PYDANTIC_VERSION diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index 288cf65150..2220b40d73 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -3,14 +3,7 @@ import datetime from collections import defaultdict, deque from decimal import Decimal from enum import Enum -from ipaddress import ( - IPv4Address, - IPv4Interface, - IPv4Network, - IPv6Address, - IPv6Interface, - IPv6Network, -) +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network from pathlib import Path, PurePath from re import Pattern from types import GeneratorType diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 9ef584cd1a..82b2f27234 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,10 +1,9 @@ -from pydantic import BaseModel - -from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction -from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor, APIBasedExtensionPoint +from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor from core.helper.encrypter import decrypt_token +from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult from extensions.ext_database import db from models.api_based_extension import APIBasedExtension +from pydantic import BaseModel class ModerationInputParams(BaseModel): diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index ce4e574038..1cce8f18f2 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod -from typing import Optional -from pydantic import BaseModel from enum import Enum +from typing import Optional from core.extension.extensible import Extensible, ExtensionModule +from pydantic import BaseModel class ModerationAction(Enum): diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 168b9d43f8..b4f178bfb9 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -1,4 +1,4 @@ -from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction +from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult class KeywordsModeration(Moderation): diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index e2c1fdf81d..bc868c2d52 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,6 +1,6 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction +from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult class OpenAIModeration(Moderation): diff --git a/api/core/prompt/output_parser/rule_config_generator.py b/api/core/prompt/output_parser/rule_config_generator.py index 84df4d0c34..61165d628e 100644 --- a/api/core/prompt/output_parser/rule_config_generator.py +++ b/api/core/prompt/output_parser/rule_config_generator.py @@ -1,7 +1,7 @@ from typing import Any -from langchain.schema import BaseOutputParser, OutputParserException from core.prompt.prompts import RULE_CONFIG_GENERATE_TEMPLATE +from langchain.schema import BaseOutputParser, OutputParserException from libs.json_in_md_parser import parse_and_check_json_markdown diff --git a/api/core/prompt/output_parser/suggested_questions_after_answer.py b/api/core/prompt/output_parser/suggested_questions_after_answer.py index d8bb0809cf..49501a2dd7 100644 --- a/api/core/prompt/output_parser/suggested_questions_after_answer.py +++ b/api/core/prompt/output_parser/suggested_questions_after_answer.py @@ -2,10 +2,9 @@ import json import re from typing import Any -from langchain.schema import BaseOutputParser - from core.model_runtime.errors.invoke import InvokeError from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT +from langchain.schema import BaseOutputParser class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 19c8e2d5ad..01cad0c1d4 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,15 +1,16 @@ +import enum import json import os import re -import enum from typing import List, Optional, Tuple, cast -from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, \ - AdvancedCompletionPromptTemplateEntity +from core.entities.application_entities import (AdvancedCompletionPromptTemplateEntity, ModelConfigEntity, + PromptTemplateEntity) from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage, \ - TextPromptMessageContent, PromptMessageRole, AssistantPromptMessage +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole, + SystemPromptMessage, TextPromptMessageContent, + UserPromptMessage) from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.prompt_builder import PromptBuilder diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 1265c4e423..12176dd608 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -3,22 +3,21 @@ from collections import defaultdict from json import JSONDecodeError from typing import Optional -from sqlalchemy.exc import IntegrityError - from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity -from core.entities.provider_configuration import ProviderConfigurations, ProviderConfiguration, ProviderModelBundle -from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, CustomModelConfiguration, \ - SystemConfiguration, QuotaConfiguration +from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle +from core.entities.provider_entities import (CustomConfiguration, CustomModelConfiguration, CustomProviderConfiguration, + QuotaConfiguration, SystemConfiguration) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \ - ConfigurateMethod +from core.model_runtime.entities.provider_entities import (ConfigurateMethod, CredentialFormSchema, FormType, + ProviderEntity) from core.model_runtime.model_providers import model_provider_factory from extensions import ext_hosting_provider from extensions.ext_database import db -from models.provider import TenantDefaultModel, Provider, ProviderModel, ProviderQuotaType, ProviderType, \ - TenantPreferredModelProvider +from models.provider import (Provider, ProviderModel, ProviderQuotaType, ProviderType, TenantDefaultModel, + TenantPreferredModelProvider) +from sqlalchemy.exc import IntegrityError class ProviderManager: diff --git a/api/core/rerank/rerank.py b/api/core/rerank/rerank.py index 984cdb4003..4d2f84b492 100644 --- a/api/core/rerank/rerank.py +++ b/api/core/rerank/rerank.py @@ -1,8 +1,7 @@ from typing import List, Optional -from langchain.schema import Document - from core.model_manager import ModelInstance +from langchain.schema import Document class RerankRunner: diff --git a/api/core/spiltter/fixed_text_splitter.py b/api/core/spiltter/fixed_text_splitter.py index bddaad2920..80d609d800 100644 --- a/api/core/spiltter/fixed_text_splitter.py +++ b/api/core/spiltter/fixed_text_splitter.py @@ -1,15 +1,12 @@ """Functionality for splitting text.""" from __future__ import annotations -from typing import ( - Any, - List, - Optional, -) - -from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter, TS, Type, Union, AbstractSet, Literal, Collection +from typing import Any, List, Optional from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from langchain.text_splitter import (TS, AbstractSet, Collection, Literal, RecursiveCharacterTextSplitter, + TokenTextSplitter, Type, Union) + class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): """ diff --git a/api/core/third_party/langchain/llms/fake.py b/api/core/third_party/langchain/llms/fake.py index 6448e506bf..64117477e1 100644 --- a/api/core/third_party/langchain/llms/fake.py +++ b/api/core/third_party/langchain/llms/fake.py @@ -1,9 +1,9 @@ import time -from typing import List, Optional, Any, Mapping +from typing import Any, List, Mapping, Optional from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chat_models.base import SimpleChatModel -from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration +from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult class FakeLLM(SimpleChatModel): diff --git a/api/core/third_party/spark/spark_llm.py b/api/core/third_party/spark/spark_llm.py index b11e7c20b5..637c1bc740 100644 --- a/api/core/third_party/spark/spark_llm.py +++ b/api/core/third_party/spark/spark_llm.py @@ -4,12 +4,11 @@ import hashlib import hmac import json import queue -from typing import Optional -from urllib.parse import urlparse import ssl from datetime import datetime from time import mktime -from urllib.parse import urlencode +from typing import Optional +from urllib.parse import urlencode, urlparse from wsgiref.handlers import format_date_time import websocket diff --git a/api/core/tool/current_datetime_tool.py b/api/core/tool/current_datetime_tool.py index 874188fc3d..3bb2bb5eaa 100644 --- a/api/core/tool/current_datetime_tool.py +++ b/api/core/tool/current_datetime_tool.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Type from langchain.tools import BaseTool -from pydantic import Field, BaseModel +from pydantic import BaseModel, Field class DatetimeToolInput(BaseModel): diff --git a/api/core/tool/dataset_multi_retriever_tool.py b/api/core/tool/dataset_multi_retriever_tool.py index aa00cac573..97cf7c9dfa 100644 --- a/api/core/tool/dataset_multi_retriever_tool.py +++ b/api/core/tool/dataset_multi_retriever_tool.py @@ -1,20 +1,19 @@ import json import threading -from typing import Type, Optional, List - -from flask import current_app, Flask -from langchain.tools import BaseTool -from pydantic import Field, BaseModel +from typing import List, Optional, Type from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.embedding.cached_embedding import CacheEmbedding -from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rerank.rerank import RerankRunner from extensions.ext_database import db -from models.dataset import Dataset, DocumentSegment, Document +from flask import Flask, current_app +from langchain.tools import BaseTool +from models.dataset import Dataset, Document, DocumentSegment +from pydantic import BaseModel, Field from services.retrieval_service import RetrievalService default_retrieval_model = { diff --git a/api/core/tool/dataset_retriever_tool.py b/api/core/tool/dataset_retriever_tool.py index 2cf779a488..9049b5e691 100644 --- a/api/core/tool/dataset_retriever_tool.py +++ b/api/core/tool/dataset_retriever_tool.py @@ -1,19 +1,18 @@ import threading -from typing import Type, Optional, List - -from flask import current_app -from langchain.tools import BaseTool -from pydantic import Field, BaseModel +from typing import List, Optional, Type from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.embedding.cached_embedding import CacheEmbedding -from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig +from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rerank.rerank import RerankRunner from extensions.ext_database import db -from models.dataset import Dataset, DocumentSegment, Document +from flask import current_app +from langchain.tools import BaseTool +from models.dataset import Dataset, Document, DocumentSegment +from pydantic import BaseModel, Field from services.retrieval_service import RetrievalService default_retrieval_model = { diff --git a/api/core/tool/serpapi_wrapper.py b/api/core/tool/serpapi_wrapper.py index e9889bc97e..0c3f107d94 100644 --- a/api/core/tool/serpapi_wrapper.py +++ b/api/core/tool/serpapi_wrapper.py @@ -1,5 +1,5 @@ from langchain import SerpAPIWrapper -from pydantic import Field, BaseModel +from pydantic import BaseModel, Field class OptimizedSerpAPIInput(BaseModel): diff --git a/api/core/tool/web_reader_tool.py b/api/core/tool/web_reader_tool.py index 04cbe0cca3..18a0e93721 100644 --- a/api/core/tool/web_reader_tool.py +++ b/api/core/tool/web_reader_tool.py @@ -7,10 +7,14 @@ import subprocess import tempfile import unicodedata from contextlib import contextmanager -from typing import Type, Any +from typing import Any, Type import requests -from bs4 import BeautifulSoup, NavigableString, Comment, CData +from bs4 import BeautifulSoup, CData, Comment, NavigableString +from core.chain.llm_chain import LLMChain +from core.data_loader import file_extractor +from core.data_loader.file_extractor import FileExtractor +from core.entities.application_entities import ModelConfigEntity from langchain.chains import RefineDocumentsChain from langchain.chains.summarize import refine_prompts from langchain.schema import Document @@ -20,11 +24,6 @@ from newspaper import Article from pydantic import BaseModel, Field from regex import regex -from core.chain.llm_chain import LLMChain -from core.data_loader import file_extractor -from core.data_loader.file_extractor import FileExtractor -from core.entities.application_entities import ModelConfigEntity - FULL_TEMPLATE = """ TITLE: {title} AUTHORS: {authors} diff --git a/api/core/vector_store/qdrant_vector_store.py b/api/core/vector_store/qdrant_vector_store.py index e4f6c2c78f..06544766f3 100644 --- a/api/core/vector_store/qdrant_vector_store.py +++ b/api/core/vector_store/qdrant_vector_store.py @@ -1,10 +1,9 @@ -from typing import cast, Any - -from langchain.schema import Document -from qdrant_client.http.models import Filter, PointIdsList, FilterSelector -from qdrant_client.local.qdrant_local import QdrantLocal +from typing import Any, cast from core.vector_store.vector.qdrant import Qdrant +from langchain.schema import Document +from qdrant_client.http.models import Filter, FilterSelector, PointIdsList +from qdrant_client.local.qdrant_local import QdrantLocal class QdrantVectorStore(Qdrant): diff --git a/api/core/vector_store/vector/milvus.py b/api/core/vector_store/vector/milvus.py index 013172d826..5c2b11cff8 100644 --- a/api/core/vector_store/vector/milvus.py +++ b/api/core/vector_store/vector/milvus.py @@ -2,11 +2,10 @@ from __future__ import annotations import logging -from typing import Any, Iterable, List, Optional, Tuple, Union, Sequence +from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union from uuid import uuid4 import numpy as np - from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.vectorstores.base import VectorStore @@ -235,13 +234,7 @@ class Milvus(VectorStore): def _create_collection( self, embeddings: list, metadatas: Optional[list[dict]] = None ) -> None: - from pymilvus import ( - Collection, - CollectionSchema, - DataType, - FieldSchema, - MilvusException, - ) + from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException from pymilvus.orm.types import infer_dtype_bydata # Determine embedding dim diff --git a/api/core/vector_store/vector/qdrant.py b/api/core/vector_store/vector/qdrant.py index 33ba0908dd..00d6aac536 100644 --- a/api/core/vector_store/vector/qdrant.py +++ b/api/core/vector_store/vector/qdrant.py @@ -7,28 +7,14 @@ import uuid import warnings from itertools import islice from operator import itemgetter -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generator, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Type, Union import numpy as np - from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.vectorstores import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance -from qdrant_client.http.models import PayloadSchemaType, FilterSelector, TextIndexParams, TokenizerType, TextIndexType +from qdrant_client.http.models import FilterSelector, PayloadSchemaType, TextIndexParams, TextIndexType, TokenizerType if TYPE_CHECKING: from qdrant_client import grpc # noqa diff --git a/api/core/vector_store/vector/weaviate.py b/api/core/vector_store/vector/weaviate.py index e00ca6978b..3f7ff58ac4 100644 --- a/api/core/vector_store/vector/weaviate.py +++ b/api/core/vector_store/vector/weaviate.py @@ -6,7 +6,6 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type from uuid import uuid4 import numpy as np - from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.utils import get_from_dict_or_env diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index 3aa5db286f..88d226d303 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -1,9 +1,9 @@ -from .create_installed_app_when_app_created import handle -from .delete_installed_app_when_app_deleted import handle -from .clean_when_document_deleted import handle from .clean_when_dataset_deleted import handle -from .update_app_dataset_join_when_app_model_config_updated import handle -from .generate_conversation_name_when_first_message_created import handle +from .clean_when_document_deleted import handle from .create_document_index import handle +from .create_installed_app_when_app_created import handle from .deduct_quota_when_messaeg_created import handle +from .delete_installed_app_when_app_deleted import handle +from .generate_conversation_name_when_first_message_created import handle +from .update_app_dataset_join_when_app_model_config_updated import handle from .update_provider_last_used_at_when_messaeg_created import handle diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index a5f5c4d8f4..058a9b4b00 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -1,16 +1,15 @@ -from events.dataset_event import dataset_was_deleted -from events.event_handlers.document_index_event import document_index_created import datetime import logging import time import click from celery import shared_task -from werkzeug.exceptions import NotFound - -from core.indexing_runner import IndexingRunner, DocumentIsPausedException +from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from events.dataset_event import dataset_was_deleted +from events.event_handlers.document_index_event import document_index_created from extensions.ext_database import db from models.dataset import Document +from werkzeug.exceptions import NotFound @document_index_created.connect diff --git a/api/events/event_handlers/deduct_quota_when_messaeg_created.py b/api/events/event_handlers/deduct_quota_when_messaeg_created.py index d3f69fab16..848c79802b 100644 --- a/api/events/event_handlers/deduct_quota_when_messaeg_created.py +++ b/api/events/event_handlers/deduct_quota_when_messaeg_created.py @@ -2,7 +2,7 @@ from core.entities.application_entities import ApplicationGenerateEntity from core.entities.provider_entities import QuotaUnit from events.message_event import message_was_created from extensions.ext_database import db -from models.provider import ProviderType, Provider +from models.provider import Provider, ProviderType @message_was_created.connect diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index ec3e360f09..b27105f4d0 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,6 +1,6 @@ from datetime import timedelta -from celery import Task, Celery +from celery import Celery, Task from flask import Flask diff --git a/api/extensions/ext_hosting_provider.py b/api/extensions/ext_hosting_provider.py index 49e2fcb0c7..5752ec7f4c 100644 --- a/api/extensions/ext_hosting_provider.py +++ b/api/extensions/ext_hosting_provider.py @@ -1,6 +1,5 @@ -from flask import Flask - from core.hosting_configuration import HostingConfiguration +from flask import Flask hosting_configuration = HostingConfiguration() diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index f00b300808..c758ccb7df 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -1,5 +1,5 @@ import redis -from redis.connection import SSLConnection, Connection +from redis.connection import Connection, SSLConnection redis_client = redis.Redis() diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index f591e8173e..2390c5fa69 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -1,7 +1,7 @@ import os import shutil from contextlib import closing -from typing import Union, Generator +from typing import Generator, Union import boto3 from botocore.exceptions import ClientError diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index 749e9900de..2ccc9ddfe0 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -1,5 +1,4 @@ from flask_restful import fields - from libs.helper import TimestampField diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index e9db1aca50..f303c37864 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,5 +1,4 @@ from flask_restful import fields - from libs.helper import TimestampField app_detail_kernel_fields = { diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 661cb0ea73..5ab73115d8 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,5 +1,4 @@ from flask_restful import fields - from libs.helper import TimestampField diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index 6f3c920c85..a7035018df 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -1,5 +1,4 @@ from flask_restful import fields - from libs.helper import TimestampField integrate_icon_fields = { diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index 94d905eafe..bf115659ef 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -1,6 +1,5 @@ -from flask_restful import fields - from fields.dataset_fields import dataset_fields +from flask_restful import fields from libs.helper import TimestampField document_fields = { diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 2ef379dabc..5f2322003a 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,5 +1,4 @@ from flask_restful import fields - from libs.helper import TimestampField upload_config_fields = { diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index 541e56a378..bb8805417e 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -1,5 +1,4 @@ from flask_restful import fields - from libs.helper import TimestampField document_fields = { diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index 79af54bdd6..95b2088c2d 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -1,5 +1,4 @@ from flask_restful import fields - from libs.helper import TimestampField app_fields = { diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index df0a1104fe..5995abbcfa 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -1,7 +1,6 @@ -from flask_restful import fields - -from libs.helper import TimestampField from fields.conversation_fields import message_file_fields +from flask_restful import fields +from libs.helper import TimestampField feedback_fields = { 'rating': fields.String diff --git a/api/libs/external_api.py b/api/libs/external_api.py index b5cc8fb9c5..901ce89690 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -1,7 +1,7 @@ import re import sys -from flask import got_request_exception, current_app +from flask import current_app, got_request_exception from flask_restful import Api, http_status_message from werkzeug.datastructures import Headers from werkzeug.exceptions import HTTPException diff --git a/api/libs/helper.py b/api/libs/helper.py index b306fee2a7..b3675a635a 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -1,12 +1,12 @@ # -*- coding:utf-8 -*- +import random import re +import string import subprocess import uuid from datetime import datetime from hashlib import sha256 from zoneinfo import available_timezones -import random -import string from flask_restful import fields diff --git a/api/libs/login.py b/api/libs/login.py index c37f761f37..06c6a837af 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,18 +1,14 @@ import os from functools import wraps -from flask import current_app -from flask import g -from flask import has_request_context -from flask import request, session +from extensions.ext_database import db +from flask import current_app, g, has_request_context, request, session from flask_login import user_logged_in from flask_login.config import EXEMPT_METHODS +from models.account import Account, Tenant, TenantAccountJoin from werkzeug.exceptions import Unauthorized from werkzeug.local import LocalProxy -from extensions.ext_database import db -from models.account import Account, Tenant, TenantAccountJoin - #: A proxy for the current user. If no user is logged in, this will be an #: anonymous user current_user = LocalProxy(lambda: _get_user()) diff --git a/api/libs/oauth.py b/api/libs/oauth.py index c89ac6d653..2a91d9941a 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -3,9 +3,8 @@ import urllib.parse from dataclasses import dataclass import requests -from flask_login import current_user - from extensions.ext_database import db +from flask_login import current_user from models.source import DataSourceBinding diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 740b96924c..1cf84e808a 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -2,9 +2,8 @@ import json import urllib.parse import requests -from flask_login import current_user - from extensions.ext_database import db +from flask_login import current_user from models.source import DataSourceBinding diff --git a/api/libs/passport.py b/api/libs/passport.py index c3bd9e566f..1116cb7293 100644 --- a/api/libs/passport.py +++ b/api/libs/passport.py @@ -1,7 +1,9 @@ # -*- coding:utf-8 -*- import jwt -from werkzeug.exceptions import Unauthorized from flask import current_app +from werkzeug.exceptions import Unauthorized + + class PassportService: def __init__(self): self.sk = current_app.config.get('SECRET_KEY') diff --git a/api/libs/rsa.py b/api/libs/rsa.py index 6aaa811e69..80ee7c18ae 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -1,10 +1,9 @@ # -*- coding:utf-8 -*- import hashlib -from Crypto.Cipher import PKCS1_OAEP, AES +from Crypto.Cipher import AES, PKCS1_OAEP from Crypto.PublicKey import RSA from Crypto.Random import get_random_bytes - from extensions.ext_redis import redis_client from extensions.ext_storage import storage diff --git a/api/migrations/env.py b/api/migrations/env.py index 0ac25ee989..18485c1885 100644 --- a/api/migrations/env.py +++ b/api/migrations/env.py @@ -1,9 +1,8 @@ import logging from logging.config import fileConfig -from flask import current_app - from alembic import context +from flask import current_app # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/api/migrations/versions/16fa53d9faec_add_provider_model_support.py b/api/migrations/versions/16fa53d9faec_add_provider_model_support.py index 92b1fba9c7..6791cf4578 100644 --- a/api/migrations/versions/16fa53d9faec_add_provider_model_support.py +++ b/api/migrations/versions/16fa53d9faec_add_provider_model_support.py @@ -5,8 +5,8 @@ Revises: 8d2d099ceb74 Create Date: 2023-08-06 16:57:51.248337 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py b/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py index ac10564299..adff497e0b 100644 --- a/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py +++ b/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py @@ -5,8 +5,8 @@ Revises: 88072f0caa04 Create Date: 2024-01-02 07:18:43.887428 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py index 2a8a9abcb4..9816e92dd1 100644 --- a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py +++ b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py @@ -5,8 +5,8 @@ Revises: 714aafe25d39 Create Date: 2023-12-14 11:26:12.287264 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py b/api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py index 0cf3c2c171..e933623d1c 100644 --- a/api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py +++ b/api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py @@ -5,9 +5,8 @@ Revises: d3d503a3471c Create Date: 2023-07-07 12:11:29.156057 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = '2beac44e5f5f' diff --git a/api/migrations/versions/2c8af9671032_add_qa_document_language.py b/api/migrations/versions/2c8af9671032_add_qa_document_language.py index 5ad90fa31c..1f0c145446 100644 --- a/api/migrations/versions/2c8af9671032_add_qa_document_language.py +++ b/api/migrations/versions/2c8af9671032_add_qa_document_language.py @@ -5,9 +5,8 @@ Revises: 8d2d099ceb74 Create Date: 2023-08-01 18:57:27.294973 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = '2c8af9671032' diff --git a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py index 23ecabe9d2..b06a3530b8 100644 --- a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py +++ b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py @@ -5,8 +5,8 @@ Revises: 6e2cfb077b04 Create Date: 2023-09-22 15:41:01.243183 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/46976cc39132_add_annotation_histoiry_score.py b/api/migrations/versions/46976cc39132_add_annotation_histoiry_score.py index 286e2d3e09..b47dd3c8ab 100644 --- a/api/migrations/versions/46976cc39132_add_annotation_histoiry_score.py +++ b/api/migrations/versions/46976cc39132_add_annotation_histoiry_score.py @@ -5,9 +5,8 @@ Revises: e1901f623fd0 Create Date: 2023-12-13 04:39:59.302971 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = '46976cc39132' diff --git a/api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py b/api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py index 0753c26d1c..178bd24e3c 100644 --- a/api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py +++ b/api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py @@ -5,9 +5,8 @@ Revises: 853f9b9cd3b6 Create Date: 2023-08-28 20:58:50.077056 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = '4bcffcd64aa4' diff --git a/api/migrations/versions/5022897aaceb_add_model_name_in_embedding.py b/api/migrations/versions/5022897aaceb_add_model_name_in_embedding.py index 182db6ccc3..c0f4af5a00 100644 --- a/api/migrations/versions/5022897aaceb_add_model_name_in_embedding.py +++ b/api/migrations/versions/5022897aaceb_add_model_name_in_embedding.py @@ -5,9 +5,8 @@ Revises: bf0aec5ba2cf Create Date: 2023-08-11 14:38:15.499460 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = '5022897aaceb' diff --git a/api/migrations/versions/614f77cecc48_add_last_active_at.py b/api/migrations/versions/614f77cecc48_add_last_active_at.py index d0509d2967..182f8f89f1 100644 --- a/api/migrations/versions/614f77cecc48_add_last_active_at.py +++ b/api/migrations/versions/614f77cecc48_add_last_active_at.py @@ -5,9 +5,8 @@ Revises: a45f4dfde53b Create Date: 2023-06-15 13:33:00.357467 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = '614f77cecc48' diff --git a/api/migrations/versions/64b051264f32_init.py b/api/migrations/versions/64b051264f32_init.py index d2160c128d..8c45ae898d 100644 --- a/api/migrations/versions/64b051264f32_init.py +++ b/api/migrations/versions/64b051264f32_init.py @@ -5,8 +5,8 @@ Revises: Create Date: 2023-05-13 14:26:59.085018 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py b/api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py index 255dddeec6..da27dd4426 100644 --- a/api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py +++ b/api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py @@ -5,8 +5,8 @@ Revises: 4bcffcd64aa4 Create Date: 2023-09-06 16:51:27.385844 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py index d37570e1bd..4fa322f693 100644 --- a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py +++ b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py @@ -5,8 +5,8 @@ Revises: 77e83833755c Create Date: 2023-09-13 22:16:48.027810 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py index 5e0eba623b..498b46e3c4 100644 --- a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py +++ b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py @@ -5,9 +5,8 @@ Revises: f2a6fc85e260 Create Date: 2023-12-14 06:38:02.972527 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = '714aafe25d39' diff --git a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py index 405e9520e1..c5d8c3d88d 100644 --- a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py +++ b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py @@ -5,9 +5,8 @@ Revises: 6dcb43972bdc Create Date: 2023-09-06 17:26:40.311927 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = '77e83833755c' diff --git a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py index 7be203b48a..881ffec61d 100644 --- a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py +++ b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py @@ -5,8 +5,8 @@ Revises: 2beac44e5f5f Create Date: 2023-07-10 10:26:50.074515 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py b/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py index f3c13095a6..dfb59839b4 100644 --- a/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py +++ b/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py @@ -5,8 +5,8 @@ Revises: e8883b0148c9 Create Date: 2023-08-19 17:01:57.471562 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py index 6cf4d744d4..f7625bff8c 100644 --- a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py +++ b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py @@ -5,9 +5,8 @@ Revises: fca025d3b60f Create Date: 2023-12-14 07:36:50.705362 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = '88072f0caa04' diff --git a/api/migrations/versions/8d2d099ceb74_add_qa_model_support.py b/api/migrations/versions/8d2d099ceb74_add_qa_model_support.py index e0915a5fb1..849103b071 100644 --- a/api/migrations/versions/8d2d099ceb74_add_qa_model_support.py +++ b/api/migrations/versions/8d2d099ceb74_add_qa_model_support.py @@ -5,8 +5,8 @@ Revises: a5b56fb053ef Create Date: 2023-07-18 15:25:15.293438 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py index 7aed3c5e6c..01d5631510 100644 --- a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py +++ b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py @@ -5,8 +5,8 @@ Revises: a9836e3baeee Create Date: 2023-11-09 11:39:00.006432 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py b/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py index 57b28e707f..207a9c841f 100644 --- a/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py +++ b/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py @@ -5,8 +5,8 @@ Revises: b3a09c049e8e Create Date: 2023-10-27 13:05:58.901858 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py index c7e3e801ec..c7a98b4ac6 100644 --- a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py +++ b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py @@ -5,9 +5,8 @@ Revises: 64b051264f32 Create Date: 2023-05-17 17:29:01.060435 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = '9f4e3427ea84' diff --git a/api/migrations/versions/a45f4dfde53b_add_language_to_recommend_apps.py b/api/migrations/versions/a45f4dfde53b_add_language_to_recommend_apps.py index 51e364e4d6..3014978110 100644 --- a/api/migrations/versions/a45f4dfde53b_add_language_to_recommend_apps.py +++ b/api/migrations/versions/a45f4dfde53b_add_language_to_recommend_apps.py @@ -5,9 +5,8 @@ Revises: 9f4e3427ea84 Create Date: 2023-05-25 17:50:32.052335 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = 'a45f4dfde53b' diff --git a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py index b197f8f5a8..acb6812434 100644 --- a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py +++ b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py @@ -5,9 +5,8 @@ Revises: d3d503a3471c Create Date: 2023-07-06 17:55:20.894149 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = 'a5b56fb053ef' diff --git a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py index 9b452f75ee..cf296628a9 100644 --- a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py +++ b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py @@ -5,8 +5,8 @@ Revises: 968fff4c0ab9 Create Date: 2023-11-02 04:04:57.609485 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/ab23c11305d4_add_dataset_query_variable_at_app_model_.py b/api/migrations/versions/ab23c11305d4_add_dataset_query_variable_at_app_model_.py index eec2dc64b3..eee41bf4e0 100644 --- a/api/migrations/versions/ab23c11305d4_add_dataset_query_variable_at_app_model_.py +++ b/api/migrations/versions/ab23c11305d4_add_dataset_query_variable_at_app_model_.py @@ -5,8 +5,8 @@ Revises: 6e2cfb077b04 Create Date: 2023-09-26 12:22:59.044088 """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = 'ab23c11305d4' diff --git a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py index cbb04bb01e..5682eff030 100644 --- a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py +++ b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py @@ -5,8 +5,8 @@ Revises: 2e9819ca5b28 Create Date: 2023-10-10 15:23:23.395420 """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = 'b3a09c049e8e' diff --git a/api/migrations/versions/bf0aec5ba2cf_add_provider_order.py b/api/migrations/versions/bf0aec5ba2cf_add_provider_order.py index aa9f74fe38..dfa1517462 100644 --- a/api/migrations/versions/bf0aec5ba2cf_add_provider_order.py +++ b/api/migrations/versions/bf0aec5ba2cf_add_provider_order.py @@ -5,8 +5,8 @@ Revises: e35ed59becda Create Date: 2023-08-10 00:03:44.273430 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/d3d503a3471c_add_is_deleted_to_conversations.py b/api/migrations/versions/d3d503a3471c_add_is_deleted_to_conversations.py index 4b8c702c4e..89355e57ad 100644 --- a/api/migrations/versions/d3d503a3471c_add_is_deleted_to_conversations.py +++ b/api/migrations/versions/d3d503a3471c_add_is_deleted_to_conversations.py @@ -5,9 +5,8 @@ Revises: e32f6ccb87c6 Create Date: 2023-06-27 19:13:30.897981 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = 'd3d503a3471c' diff --git a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py index 94200de9d4..32902c8eb0 100644 --- a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py +++ b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py @@ -5,8 +5,8 @@ Revises: fca025d3b60f Create Date: 2023-12-12 06:58:41.054544 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/e32f6ccb87c6_e08af0a69ccefbb59fa80c778efee300bb780980.py b/api/migrations/versions/e32f6ccb87c6_e08af0a69ccefbb59fa80c778efee300bb780980.py index 444f224fc4..3d7dd1fabf 100644 --- a/api/migrations/versions/e32f6ccb87c6_e08af0a69ccefbb59fa80c778efee300bb780980.py +++ b/api/migrations/versions/e32f6ccb87c6_e08af0a69ccefbb59fa80c778efee300bb780980.py @@ -5,8 +5,8 @@ Revises: a45f4dfde53b Create Date: 2023-06-06 19:58:33.103819 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/e35ed59becda_modify_quota_limit_field_type.py b/api/migrations/versions/e35ed59becda_modify_quota_limit_field_type.py index e9056d57f9..627366b36d 100644 --- a/api/migrations/versions/e35ed59becda_modify_quota_limit_field_type.py +++ b/api/migrations/versions/e35ed59becda_modify_quota_limit_field_type.py @@ -5,9 +5,8 @@ Revises: 16fa53d9faec Create Date: 2023-08-09 22:20:31.577953 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = 'e35ed59becda' diff --git a/api/migrations/versions/e8883b0148c9_add_dataset_model_name.py b/api/migrations/versions/e8883b0148c9_add_dataset_model_name.py index 67eaf35e52..875683d68e 100644 --- a/api/migrations/versions/e8883b0148c9_add_dataset_model_name.py +++ b/api/migrations/versions/e8883b0148c9_add_dataset_model_name.py @@ -5,9 +5,8 @@ Revises: 2c8af9671032 Create Date: 2023-08-15 20:54:58.936787 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = 'e8883b0148c9' diff --git a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py index b85a8fd023..dc9392a92c 100644 --- a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py +++ b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py @@ -5,8 +5,8 @@ Revises: 46976cc39132 Create Date: 2023-12-13 11:09:29.329584 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py b/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py index c16781c15d..1f8250c3eb 100644 --- a/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py +++ b/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py @@ -5,8 +5,8 @@ Revises: b3a09c049e8e Create Date: 2023-11-03 13:08:23.246396 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/models/account.py b/api/models/account.py index 5a86b4f8e0..81d56d974e 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -1,10 +1,10 @@ -import json import enum +import json from math import e from typing import List -from flask_login import UserMixin from extensions.ext_database import db +from flask_login import UserMixin from sqlalchemy.dialects.postgresql import UUID diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index e34cfb8f7b..200675d766 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,8 +1,7 @@ import enum -from sqlalchemy.dialects.postgresql import UUID - from extensions.ext_database import db +from sqlalchemy.dialects.postgresql import UUID class APIBasedExtensionPoint(enum.Enum): diff --git a/api/models/dataset.py b/api/models/dataset.py index 908bcda158..5e67b2b8b8 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -2,12 +2,11 @@ import json import pickle from json import JSONDecodeError -from sqlalchemy import func -from sqlalchemy.dialects.postgresql import UUID, JSONB - from extensions.ext_database import db from models.account import Account from models.model import App, UploadFile +from sqlalchemy import func +from sqlalchemy.dialects.postgresql import JSONB, UUID class Dataset(db.Model): diff --git a/api/models/model.py b/api/models/model.py index df27edb9a5..1a23cb5dc9 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,13 +1,13 @@ import json +from core.file.upload_file_parser import UploadFileParser +from extensions.ext_database import db from flask import current_app, request from flask_login import UserMixin +from libs.helper import generate_string from sqlalchemy import Float from sqlalchemy.dialects.postgresql import UUID -from core.file.upload_file_parser import UploadFileParser -from libs.helper import generate_string -from extensions.ext_database import db from .account import Account, Tenant diff --git a/api/models/provider.py b/api/models/provider.py index 4c9fd793cc..514ea47cff 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,8 +1,7 @@ from enum import Enum -from sqlalchemy.dialects.postgresql import UUID - from extensions.ext_database import db +from sqlalchemy.dialects.postgresql import UUID class ProviderType(Enum): diff --git a/api/models/source.py b/api/models/source.py index c7c04075bc..9923e38b4b 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,7 +1,6 @@ -from sqlalchemy.dialects.postgresql import UUID - from extensions.ext_database import db -from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import JSONB, UUID + class DataSourceBinding(db.Model): __tablename__ = 'data_source_bindings' diff --git a/api/models/task.py b/api/models/task.py index d85cf16d7c..fd4105b2ea 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,7 +1,8 @@ -from extensions.ext_database import db -from celery import states from datetime import datetime +from celery import states +from extensions.ext_database import db + class CeleryTask(db.Model): """Task result/status.""" diff --git a/api/models/tool.py b/api/models/tool.py index ac866e20a4..0e25659980 100644 --- a/api/models/tool.py +++ b/api/models/tool.py @@ -1,9 +1,8 @@ import json from enum import Enum -from sqlalchemy.dialects.postgresql import UUID - from extensions.ext_database import db +from sqlalchemy.dialects.postgresql import UUID class ToolProviderName(Enum): diff --git a/api/models/web.py b/api/models/web.py index b2466430b9..2957703a20 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,7 +1,6 @@ -from sqlalchemy.dialects.postgresql import UUID - from extensions.ext_database import db from models.model import Message +from sqlalchemy.dialects.postgresql import UUID class SavedMessage(db.Model): diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index 1caed9e02e..53d9500bab 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -1,11 +1,12 @@ -import app import datetime import time + +import app import click -from flask import current_app -from werkzeug.exceptions import NotFound from extensions.ext_database import db +from flask import current_app from models.dataset import Embedding +from werkzeug.exceptions import NotFound @app.celery.task(queue='dataset') diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index eb95cc5da2..f5ba46463f 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -1,13 +1,14 @@ -import logging -import app import datetime +import logging import time + +import app import click -from flask import current_app -from werkzeug.exceptions import NotFound from core.index.index import IndexBuilder from extensions.ext_database import db -from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding +from flask import current_app +from models.dataset import Dataset, DatasetCollectionBinding, DatasetQuery, Document +from werkzeug.exceptions import NotFound @app.celery.task(queue='dataset') diff --git a/api/services/account_service.py b/api/services/account_service.py index 58e9930dce..a3dade998e 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -6,23 +6,23 @@ import secrets import uuid from datetime import datetime, timedelta from hashlib import sha256 -from typing import Optional, Dict, Any - -from werkzeug.exceptions import Forbidden, Unauthorized -from flask import session, current_app -from sqlalchemy import func +from typing import Any, Dict, Optional from events.tenant_event import tenant_was_created from extensions.ext_redis import redis_client -from services.errors.account import AccountLoginError, CurrentPasswordIncorrectError, LinkAccountIntegrateError, \ - TenantNotFound, AccountNotLinkTenantError, InvalidActionError, CannotOperateSelfError, MemberNotInTenantError, \ - RoleAlreadyAssignedError, NoPermissionError, AccountRegisterError, AccountAlreadyInTenantError +from flask import current_app, session from libs.helper import get_remote_ip +from libs.passport import PassportService from libs.password import compare_password, hash_password from libs.rsa import generate_key_pair -from libs.passport import PassportService from models.account import * +from services.errors.account import (AccountAlreadyInTenantError, AccountLoginError, AccountNotLinkTenantError, + AccountRegisterError, CannotOperateSelfError, CurrentPasswordIncorrectError, + InvalidActionError, LinkAccountIntegrateError, MemberNotInTenantError, + NoPermissionError, RoleAlreadyAssignedError, TenantNotFound) +from sqlalchemy import func from tasks.mail_invite_member_task import send_invite_member_mail_task +from werkzeug.exceptions import Forbidden, Unauthorized def _create_tenant_for_account(account) -> Tenant: diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 95c9c615ea..d3cd911125 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,9 +1,15 @@ import copy +from core.prompt.advanced_prompt_templates import (BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CONTEXT, + CHAT_APP_CHAT_PROMPT_CONFIG, CHAT_APP_COMPLETION_PROMPT_CONFIG, + COMPLETION_APP_CHAT_PROMPT_CONFIG, + COMPLETION_APP_COMPLETION_PROMPT_CONFIG, CONTEXT) from core.prompt.prompt_transform import AppMode -from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \ - BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT + class AdvancedPromptTemplateService: diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index b84e87bf44..1ffe44910b 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -3,20 +3,19 @@ import json import uuid import pandas as pd -from flask_login import current_user -from sqlalchemy import or_ -from werkzeug.datastructures import FileStorage -from werkzeug.exceptions import NotFound - from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.model import MessageAnnotation, Message, App, AppAnnotationHitHistory, AppAnnotationSetting +from flask_login import current_user +from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation +from sqlalchemy import or_ from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task -from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task -from tasks.annotation.disable_annotation_reply_task import disable_annotation_reply_task -from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task -from tasks.annotation.delete_annotation_index_task import delete_annotation_index_task from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task +from tasks.annotation.delete_annotation_index_task import delete_annotation_index_task +from tasks.annotation.disable_annotation_reply_task import disable_annotation_reply_task +from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task +from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import NotFound class AppAnnotationService: diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index d4e7d5be3d..8441bbedb3 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -1,7 +1,7 @@ +from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor +from core.helper.encrypter import decrypt_token, encrypt_token from extensions.ext_database import db from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from core.helper.encrypter import encrypt_token, decrypt_token -from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor class APIBasedExtensionService: diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 45ded1b2a6..bf7dfab747 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,17 +1,16 @@ import re import uuid +from core.agent.agent_executor import PlanningStrategy from core.external_data_tool.factory import ExternalDataToolFactory from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers import model_provider_factory from core.moderation.factory import ModerationFactory from core.prompt.prompt_transform import AppMode -from core.agent.agent_executor import PlanningStrategy from core.provider_manager import ProviderManager from models.account import Account from services.dataset_service import DatasetService - SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 2f4a73c3ac..8d9a1e3b89 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -1,9 +1,10 @@ import io -from werkzeug.datastructures import FileStorage from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError +from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError) +from werkzeug.datastructures import FileStorage FILE_SIZE = 15 FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024 diff --git a/api/services/billing_service.py b/api/services/billing_service.py index d6761680ee..d798a75bd2 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,7 +1,6 @@ import os import requests - from extensions.ext_database import db from models.account import TenantAccountJoin diff --git a/api/services/completion_service.py b/api/services/completion_service.py index e0f7cd833d..069aa28537 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -1,18 +1,17 @@ import json -from typing import Generator, Union, Any - -from sqlalchemy import and_ +from typing import Any, Generator, Union from core.application_manager import ApplicationManager from core.entities.application_entities import InvokeFrom from core.file.message_file_parser import MessageFileParser from extensions.ext_database import db -from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message +from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message from services.app_model_config_service import AppModelConfigService from services.errors.app import MoreLikeThisDisabledError from services.errors.app_model_config import AppModelConfigBrokenError -from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError +from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError from services.errors.message import MessageNotExistsError +from sqlalchemy import and_ class CompletionService: diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 1dcbb52867..ac3df380b2 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,10 +1,10 @@ -from typing import Union, Optional +from typing import Optional, Union from core.generator.llm_generator import LLMGenerator -from libs.infinite_scroll_pagination import InfiniteScrollPagination from extensions.ext_database import db +from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account -from models.model import Conversation, App, EndUser, Message +from models.model import App, Conversation, EndUser, Message from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.errors.message import MessageNotExistsError diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index f71f7fe1f8..236ffe7008 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1,29 +1,26 @@ +import datetime import json import logging -import datetime -import time import random +import time import uuid -from typing import Optional, List, cast +from typing import List, Optional, cast -from flask import current_app -from sqlalchemy import func - -from core.index.index import IndexBuilder from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.index.index import IndexBuilder from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from extensions.ext_redis import redis_client -from flask_login import current_user - from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db +from extensions.ext_redis import redis_client +from flask import current_app +from flask_login import current_user from libs import helper from models.account import Account -from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment, \ - DatasetCollectionBinding +from models.dataset import (AppDatasetJoin, Dataset, DatasetCollectionBinding, DatasetProcessRule, DatasetQuery, + Document, DocumentSegment) from models.model import UploadFile from models.source import DataSourceBinding from services.errors.account import NoPermissionError @@ -31,12 +28,13 @@ from services.errors.dataset import DatasetNameDuplicateError from services.errors.document import DocumentIndexingError from services.errors.file import FileNotExistsError from services.vector_service import VectorService +from sqlalchemy import func from tasks.clean_notion_document_task import clean_notion_document_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task +from tasks.delete_segment_from_index_task import delete_segment_from_index_task from tasks.document_indexing_task import document_indexing_task from tasks.document_indexing_update_task import document_indexing_update_task from tasks.recover_document_indexing_task import recover_document_indexing_task -from tasks.delete_segment_from_index_task import delete_segment_from_index_task class DatasetService: diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index d0cd71af4d..58f00135dc 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,16 +1,16 @@ from enum import Enum from typing import Optional -from flask import current_app -from pydantic import BaseModel - from core.entities.model_entities import ModelStatus, ModelWithProviderEntity from core.entities.provider_entities import QuotaConfiguration from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType, ProviderModel -from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderCredentialSchema, \ - ModelCredentialSchema, ProviderHelpEntity, SimpleProviderEntity -from models.provider import ProviderType, ProviderQuotaType +from core.model_runtime.entities.provider_entities import (ConfigurateMethod, ModelCredentialSchema, + ProviderCredentialSchema, ProviderHelpEntity, + SimpleProviderEntity) +from flask import current_app +from models.provider import ProviderQuotaType, ProviderType +from pydantic import BaseModel class CustomConfigurationStatus(Enum): diff --git a/api/services/feature_service.py b/api/services/feature_service.py index bf7b378f38..75feaf7800 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -1,6 +1,5 @@ -from pydantic import BaseModel from flask import current_app - +from pydantic import BaseModel from services.billing_service import BillingService diff --git a/api/services/file_service.py b/api/services/file_service.py index fb61bfa9be..14d71fe546 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -3,18 +3,17 @@ import hashlib import uuid from typing import Generator, Tuple, Union -from flask import current_app -from flask_login import current_user -from werkzeug.datastructures import FileStorage -from werkzeug.exceptions import NotFound - from core.data_loader.file_extractor import FileExtractor from core.file.upload_file_parser import UploadFileParser -from extensions.ext_storage import storage from extensions.ext_database import db +from extensions.ext_storage import storage +from flask import current_app +from flask_login import current_user from models.account import Account -from models.model import UploadFile, EndUser +from models.model import EndUser, UploadFile from services.errors.file import FileTooLargeError, UnsupportedFileTypeError +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import NotFound ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv', 'jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index ea4e8ecf0c..0f38241948 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -4,19 +4,18 @@ import time from typing import List import numpy as np -from flask import current_app -from langchain.embeddings.base import Embeddings -from langchain.schema import Document -from sklearn.manifold import TSNE - from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rerank.rerank import RerankRunner from extensions.ext_database import db +from flask import current_app +from langchain.embeddings.base import Embeddings +from langchain.schema import Document from models.account import Account -from models.dataset import Dataset, DocumentSegment, DatasetQuery +from models.dataset import Dataset, DatasetQuery, DocumentSegment from services.retrieval_service import RetrievalService +from sklearn.manifold import TSNE default_retrieval_model = { 'search_method': 'semantic_search', diff --git a/api/services/message_service.py b/api/services/message_service.py index a2cfb9a7f2..79feb1c669 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,19 +1,19 @@ import json -from typing import Optional, Union, List +from typing import List, Optional, Union from core.generator.llm_generator import LLMGenerator from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from libs.infinite_scroll_pagination import InfiniteScrollPagination from extensions.ext_database import db +from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account -from models.model import App, EndUser, Message, MessageFeedback, AppModelConfig +from models.model import App, AppModelConfig, EndUser, Message, MessageFeedback from services.conversation_service import ConversationService from services.errors.app_model_config import AppModelConfigBrokenError -from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError -from services.errors.message import FirstMessageNotExistsError, MessageNotExistsError, LastMessageNotExistsError, \ - SuggestedQuestionsAfterAnswerDisabledError +from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError +from services.errors.message import (FirstMessageNotExistsError, LastMessageNotExistsError, MessageNotExistsError, + SuggestedQuestionsAfterAnswerDisabledError) class MessageService: diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 422a59b0cc..906ef57a39 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,20 +1,21 @@ import logging import mimetypes import os -from typing import Optional, cast, Tuple +from typing import Optional, Tuple, cast import requests -from flask import current_app - from core.entities.model_entities import ModelStatus from core.model_runtime.entities.model_entities import ModelType, ParameterRule from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.provider_manager import ProviderManager +from flask import current_app from models.provider import ProviderType -from services.entities.model_provider_entities import ProviderResponse, CustomConfigurationResponse, \ - SystemConfigurationResponse, CustomConfigurationStatus, ProviderWithModelsResponse, ModelResponse, \ - DefaultModelResponse, ModelWithProviderEntityResponse, SimpleProviderEntityResponse +from services.entities.model_provider_entities import (CustomConfigurationResponse, CustomConfigurationStatus, + DefaultModelResponse, ModelResponse, + ModelWithProviderEntityResponse, ProviderResponse, + ProviderWithModelsResponse, SimpleProviderEntityResponse, + SystemConfigurationResponse) logger = logging.getLogger(__name__) diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index d0be7bca80..d472f8cfbc 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -1,6 +1,6 @@ -from models.model import AppModelConfig, App from core.moderation.factory import ModerationFactory, ModerationOutputsResult from extensions.ext_database import db +from models.model import App, AppModelConfig class ModerationService: diff --git a/api/services/retrieval_service.py b/api/services/retrieval_service.py index 6e325b0089..2efa0dbee4 100644 --- a/api/services/retrieval_service.py +++ b/api/services/retrieval_service.py @@ -1,12 +1,13 @@ from typing import Optional -from flask import current_app, Flask -from langchain.embeddings.base import Embeddings + from core.index.vector_index.vector_index import VectorIndex from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rerank.rerank import RerankRunner from extensions.ext_database import db +from flask import Flask, current_app +from langchain.embeddings.base import Embeddings from models.dataset import Dataset default_retrieval_model = { diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 21ed0f7d64..f1113c1505 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -1,7 +1,7 @@ from typing import Optional, Union -from libs.infinite_scroll_pagination import InfiniteScrollPagination from extensions.ext_database import db +from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account from models.model import App, EndUser from models.web import SavedMessage diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 45bf611fd4..ee06bd175a 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,10 +1,8 @@ -from typing import Optional, List - -from langchain.schema import Document +from typing import List, Optional from core.index.index import IndexBuilder - +from langchain.schema import Document from models.dataset import Dataset, DocumentSegment diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index 5f9fc883f0..06e3f6fd53 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -1,7 +1,7 @@ from typing import Optional, Union -from libs.infinite_scroll_pagination import InfiniteScrollPagination from extensions.ext_database import db +from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account from models.model import App, EndUser from models.web import PinnedConversation diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index e92ad603e4..1bdf9d8631 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,11 +1,10 @@ +from extensions.ext_database import db from flask import current_app from flask_login import current_user -from extensions.ext_database import db from models.account import Tenant, TenantAccountJoin, TenantAccountJoinRole from models.provider import Provider - -from services.feature_service import FeatureService from services.account_service import TenantService +from services.feature_service import FeatureService class WorkspaceService: diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 20d54d97d4..ea5c17b487 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -4,14 +4,13 @@ import time import click from celery import shared_task -from langchain.schema import Document -from werkzeug.exceptions import NotFound - from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import DocumentSegment +from langchain.schema import Document from models.dataset import Document as DatasetDocument +from models.dataset import DocumentSegment +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index 620413dffb..4c0e13feb3 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -3,10 +3,8 @@ import time import click from celery import shared_task -from langchain.schema import Document - from core.index.index import IndexBuilder - +from langchain.schema import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 8c908bf97e..a7d026b15e 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -4,15 +4,14 @@ import time import click from celery import shared_task -from langchain.schema import Document -from werkzeug.exceptions import NotFound - from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client +from langchain.schema import Document from models.dataset import Dataset -from models.model import MessageAnnotation, App, AppAnnotationSetting +from models.model import App, AppAnnotationSetting, MessageAnnotation from services.dataset_service import DatasetCollectionBindingService +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index ee665fbb92..6cb51eecdb 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -4,13 +4,12 @@ import time import click from celery import shared_task -from werkzeug.exceptions import NotFound - from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset -from models.model import MessageAnnotation, App, AppAnnotationSetting +from models.model import App, AppAnnotationSetting, MessageAnnotation +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 5a68b9285c..42c3b23836 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -4,15 +4,14 @@ import time import click from celery import shared_task -from langchain.schema import Document -from werkzeug.exceptions import NotFound - from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client +from langchain.schema import Document from models.dataset import Dataset -from models.model import MessageAnnotation, App, AppAnnotationSetting +from models.model import App, AppAnnotationSetting, MessageAnnotation from services.dataset_service import DatasetCollectionBindingService +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index e477b8c2c8..c1ca161c9a 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -3,10 +3,8 @@ import time import click from celery import shared_task -from langchain.schema import Document - from core.index.index import IndexBuilder - +from langchain.schema import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 9463820d5a..1d5d966098 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -6,8 +6,6 @@ from typing import List, cast import click from celery import shared_task -from sqlalchemy import func - from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -15,7 +13,8 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper -from models.dataset import DocumentSegment, Dataset, Document +from models.dataset import Dataset, Document, DocumentSegment +from sqlalchemy import func @shared_task(queue='dataset') diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 8f5e37f49b..5813f38706 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -3,13 +3,12 @@ import time import click from celery import shared_task -from flask import current_app - from core.index.index import IndexBuilder from core.index.vector_index.vector_index import VectorIndex from extensions.ext_database import db -from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \ - AppDatasetJoin, Document +from flask import current_app +from models.dataset import (AppDatasetJoin, Dataset, DatasetKeywordTable, DatasetProcessRule, DatasetQuery, Document, + DocumentSegment) @shared_task(queue='dataset') diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 2738ede28a..1750eb80aa 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -3,10 +3,9 @@ import time import click from celery import shared_task - from core.index.index import IndexBuilder from extensions.ext_database import db -from models.dataset import DocumentSegment, Dataset +from models.dataset import Dataset, DocumentSegment @shared_task(queue='dataset') diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index d68b257c4c..46b066d7ba 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -4,10 +4,9 @@ from typing import List import click from celery import shared_task - from core.index.index import IndexBuilder from extensions.ext_database import db -from models.dataset import DocumentSegment, Dataset, Document +from models.dataset import Dataset, Document, DocumentSegment @shared_task(queue='dataset') diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 6e0af2d933..23e599cf03 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -1,17 +1,16 @@ import datetime import logging import time -from typing import Optional, List +from typing import List, Optional import click from celery import shared_task -from langchain.schema import Document -from werkzeug.exceptions import NotFound - from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client +from langchain.schema import Document from models.dataset import DocumentSegment +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 6a3b52a40b..d8a3b501ef 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -3,12 +3,12 @@ import time import click from celery import shared_task -from langchain.schema import Document - from core.index.index import IndexBuilder from extensions.ext_database import db -from models.dataset import DocumentSegment, Dataset +from langchain.schema import Document +from models.dataset import Dataset from models.dataset import Document as DatasetDocument +from models.dataset import DocumentSegment @shared_task(queue='dataset') diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index bb5a87410f..75776b2aa2 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -3,12 +3,11 @@ import time import click from celery import shared_task -from werkzeug.exceptions import NotFound - from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import DocumentSegment, Dataset, Document +from models.dataset import Dataset, Document, DocumentSegment +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 97f4fd0677..57c94c7fa1 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -3,12 +3,11 @@ import time import click from celery import shared_task -from werkzeug.exceptions import NotFound - from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 7cb388c667..319e8ddb0d 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -4,14 +4,13 @@ import time import click from celery import shared_task -from werkzeug.exceptions import NotFound - from core.data_loader.loader.notion import NotionLoader from core.index.index import IndexBuilder -from core.indexing_runner import IndexingRunner, DocumentIsPausedException +from core.indexing_runner import DocumentIsPausedException, IndexingRunner from extensions.ext_database import db -from models.dataset import Document, Dataset, DocumentSegment +from models.dataset import Dataset, Document, DocumentSegment from models.source import DataSourceBinding +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 31d083aeac..2ea6288059 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -4,11 +4,10 @@ import time import click from celery import shared_task -from werkzeug.exceptions import NotFound - -from core.indexing_runner import IndexingRunner, DocumentIsPausedException +from core.indexing_runner import DocumentIsPausedException, IndexingRunner from extensions.ext_database import db from models.dataset import Document +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index c3fbe7172c..54449662c3 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -4,12 +4,11 @@ import time import click from celery import shared_task -from werkzeug.exceptions import NotFound - from core.index.index import IndexBuilder -from core.indexing_runner import IndexingRunner, DocumentIsPausedException +from core.indexing_runner import DocumentIsPausedException, IndexingRunner from extensions.ext_database import db -from models.dataset import Document, Dataset, DocumentSegment +from models.dataset import Dataset, Document, DocumentSegment +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 8dffd01520..ce450563d2 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -4,13 +4,12 @@ import time import click from celery import shared_task -from langchain.schema import Document -from werkzeug.exceptions import NotFound - from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client +from langchain.schema import Document from models.dataset import DocumentSegment +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index 34b65dbc37..562f5fdddd 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -3,8 +3,9 @@ import time import click from celery import shared_task -from flask import current_app, render_template from extensions.ext_mail import mail +from flask import current_app, render_template + @shared_task(queue='mail') def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str): diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index a9917da9b2..e1ed87a395 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -3,11 +3,10 @@ import time import click from celery import shared_task -from werkzeug.exceptions import NotFound - -from core.indexing_runner import IndexingRunner, DocumentIsPausedException +from core.indexing_runner import DocumentIsPausedException, IndexingRunner from extensions.ext_database import db from models.dataset import Document +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index 36ec02e48e..6bb6e96261 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -3,12 +3,11 @@ import time import click from celery import shared_task -from werkzeug.exceptions import NotFound - from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import DocumentSegment, Document +from models.dataset import Document, DocumentSegment +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/update_segment_index_task.py b/api/tasks/update_segment_index_task.py index 40089ad3e4..1f6592a3e8 100644 --- a/api/tasks/update_segment_index_task.py +++ b/api/tasks/update_segment_index_task.py @@ -5,13 +5,12 @@ from typing import List, Optional import click from celery import shared_task -from langchain.schema import Document -from werkzeug.exceptions import NotFound - from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client +from langchain.schema import Document from models.dataset import DocumentSegment +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/update_segment_keyword_index_task.py b/api/tasks/update_segment_keyword_index_task.py index 284f73677c..8ae4b64137 100644 --- a/api/tasks/update_segment_keyword_index_task.py +++ b/api/tasks/update_segment_keyword_index_task.py @@ -5,13 +5,12 @@ from typing import List, Optional import click from celery import shared_task -from langchain.schema import Document -from werkzeug.exceptions import NotFound - from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client +from langchain.schema import Document from models.dataset import DocumentSegment +from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tests/integration_tests/model_runtime/__mock/anthropic.py b/api/tests/integration_tests/model_runtime/__mock/anthropic.py index 34127515a0..96fd8f2026 100644 --- a/api/tests/integration_tests/model_runtime/__mock/anthropic.py +++ b/api/tests/integration_tests/model_runtime/__mock/anthropic.py @@ -1,16 +1,14 @@ -import anthropic -from anthropic import Anthropic -from anthropic.resources.completions import Completions -from anthropic.types import completion_create_params, Completion -from anthropic._types import NOT_GIVEN, NotGiven, Headers, Query, Body - -from _pytest.monkeypatch import MonkeyPatch - -from typing import List, Union, Literal, Any, Generator -from time import sleep - -import pytest import os +from time import sleep +from typing import Any, Generator, List, Literal, Union + +import anthropic +import pytest +from _pytest.monkeypatch import MonkeyPatch +from anthropic import Anthropic +from anthropic._types import NOT_GIVEN, Body, Headers, NotGiven, Query +from anthropic.resources.completions import Completions +from anthropic.types import Completion, completion_create_params MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index 6a16586c83..4ac4dfe1f0 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -1,17 +1,15 @@ +from typing import Generator, List + +import google.generativeai.types.content_types as content_types +import google.generativeai.types.generation_types as generation_config_types +import google.generativeai.types.safety_types as safety_types +import pytest +from _pytest.monkeypatch import MonkeyPatch +from google.ai import generativelanguage as glm from google.generativeai import GenerativeModel +from google.generativeai.client import _ClientManager, configure from google.generativeai.types import GenerateContentResponse from google.generativeai.types.generation_types import BaseGenerateContentResponse -import google.generativeai.types.generation_types as generation_config_types -import google.generativeai.types.content_types as content_types -import google.generativeai.types.safety_types as safety_types -from google.generativeai.client import _ClientManager, configure - -from google.ai import generativelanguage as glm - -from typing import Generator, List -from _pytest.monkeypatch import MonkeyPatch - -import pytest current_api_key = '' diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface.py b/api/tests/integration_tests/model_runtime/__mock/huggingface.py index 52f0322be4..e1e87748cd 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface.py @@ -1,12 +1,10 @@ -from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass - -from huggingface_hub import InferenceClient - -from _pytest.monkeypatch import MonkeyPatch -from typing import List, Dict, Any +import os +from typing import Any, Dict, List import pytest -import os +from _pytest.monkeypatch import MonkeyPatch +from huggingface_hub import InferenceClient +from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py index a83ac628f5..56b7ee4bfe 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py @@ -1,11 +1,12 @@ +import re +from typing import Any, Generator, List, Literal, Optional, Union + +from _pytest.monkeypatch import MonkeyPatch from huggingface_hub import InferenceClient -from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse, Details, StreamDetails, Token +from huggingface_hub.inference._text_generation import (Details, StreamDetails, TextGenerationResponse, + TextGenerationStreamResponse, Token) from huggingface_hub.utils import BadRequestError -from typing import Literal, Optional, List, Generator, Union, Any -from _pytest.monkeypatch import MonkeyPatch - -import re class MockHuggingfaceChatClass(object): @staticmethod diff --git a/api/tests/integration_tests/model_runtime/__mock/openai.py b/api/tests/integration_tests/model_runtime/__mock/openai.py index d4b9de5c51..92fe30f4c9 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai.py @@ -1,22 +1,22 @@ -from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass -from tests.integration_tests.model_runtime.__mock.openai_chat import MockChatClass -from tests.integration_tests.model_runtime.__mock.openai_remote import MockModelClass -from tests.integration_tests.model_runtime.__mock.openai_moderation import MockModerationClass -from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass -from tests.integration_tests.model_runtime.__mock.openai_embeddings import MockEmbeddingsClass -from openai.resources.completions import Completions -from openai.resources.chat import Completions as ChatCompletions -from openai.resources.models import Models -from openai.resources.moderations import Moderations -from openai.resources.audio.transcriptions import Transcriptions -from openai.resources.embeddings import Embeddings +import os +from typing import Callable, List, Literal +import pytest # import monkeypatch from _pytest.monkeypatch import MonkeyPatch -from typing import Literal, Callable, List +from openai.resources.audio.transcriptions import Transcriptions +from openai.resources.chat import Completions as ChatCompletions +from openai.resources.completions import Completions +from openai.resources.embeddings import Embeddings +from openai.resources.models import Models +from openai.resources.moderations import Moderations +from tests.integration_tests.model_runtime.__mock.openai_chat import MockChatClass +from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass +from tests.integration_tests.model_runtime.__mock.openai_embeddings import MockEmbeddingsClass +from tests.integration_tests.model_runtime.__mock.openai_moderation import MockModerationClass +from tests.integration_tests.model_runtime.__mock.openai_remote import MockModelClass +from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass -import os -import pytest def mock_openai(monkeypatch: MonkeyPatch, methods: List[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]: """ diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py index 03e4c14ed5..dbc061b952 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py @@ -1,27 +1,26 @@ -from openai import OpenAI -from openai.types import Completion as CompletionMessage -from openai._types import NotGiven, NOT_GIVEN -from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam, \ - ChatCompletionToolChoiceOptionParam, ChatCompletionToolParam, ChatCompletionMessageToolCall -from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaFunctionCall,\ - Choice, ChoiceDelta, ChoiceDeltaToolCallFunction -from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice, ChatCompletion as _ChatCompletion -from openai.types.chat.chat_completion_message import FunctionCall, ChatCompletionMessage -from openai.types.chat.chat_completion_message_tool_call import Function -from openai.types.completion_usage import CompletionUsage -from openai.resources.chat.completions import Completions -from openai import AzureOpenAI +import re +from json import dumps, loads +from time import sleep, time +# import monkeypatch +from typing import Any, Generator, List, Literal, Optional, Union import openai.types.chat.completion_create_params as completion_create_params - -# import monkeypatch -from typing import List, Any, Generator, Union, Optional, Literal -from time import time, sleep -from json import dumps, loads - from core.model_runtime.errors.invoke import InvokeAuthorizationError +from openai import AzureOpenAI, OpenAI +from openai._types import NOT_GIVEN, NotGiven +from openai.resources.chat.completions import Completions +from openai.types import Completion as CompletionMessage +from openai.types.chat import (ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam, + ChatCompletionMessageToolCall, ChatCompletionToolChoiceOptionParam, + ChatCompletionToolParam) +from openai.types.chat.chat_completion import ChatCompletion as _ChatCompletion +from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice +from openai.types.chat.chat_completion_chunk import (Choice, ChoiceDelta, ChoiceDeltaFunctionCall, ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction) +from openai.types.chat.chat_completion_message import ChatCompletionMessage, FunctionCall +from openai.types.chat.chat_completion_message_tool_call import Function +from openai.types.completion_usage import CompletionUsage -import re class MockChatClass(object): @staticmethod diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py index 526e7b1b39..4a33a508a1 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py @@ -1,17 +1,16 @@ -from openai import BadRequestError, OpenAI, AzureOpenAI -from openai.types import Completion as CompletionMessage -from openai._types import NotGiven, NOT_GIVEN -from openai.types.completion import CompletionChoice -from openai.types.completion_usage import CompletionUsage -from openai.resources.completions import Completions - +import re +from time import sleep, time # import monkeypatch -from typing import List, Any, Generator, Union, Optional, Literal -from time import time, sleep +from typing import Any, Generator, List, Literal, Optional, Union from core.model_runtime.errors.invoke import InvokeAuthorizationError +from openai import AzureOpenAI, BadRequestError, OpenAI +from openai._types import NOT_GIVEN, NotGiven +from openai.resources.completions import Completions +from openai.types import Completion as CompletionMessage +from openai.types.completion import CompletionChoice +from openai.types.completion_usage import CompletionUsage -import re class MockCompletionsClass(object): @staticmethod diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py index 2913571739..9c3d293281 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py @@ -1,14 +1,13 @@ -from openai.resources.embeddings import Embeddings -from openai._types import NotGiven, NOT_GIVEN -from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage -from openai.types.embedding import Embedding -from openai import OpenAI - -from typing import Union, List, Literal, Any +import re +from typing import Any, List, Literal, Union from core.model_runtime.errors.invoke import InvokeAuthorizationError +from openai import OpenAI +from openai._types import NOT_GIVEN, NotGiven +from openai.resources.embeddings import Embeddings +from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage +from openai.types.embedding import Embedding -import re class MockEmbeddingsClass(object): def create_embeddings( diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py index 81fe9e99f4..634fa77096 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py @@ -1,13 +1,12 @@ -from openai.resources.moderations import Moderations -from openai.types import ModerationCreateResponse -from openai.types.moderation import Moderation, Categories, CategoryScores -from openai._types import NotGiven, NOT_GIVEN - -from typing import Union, List, Literal, Any +import re +from typing import Any, List, Literal, Union from core.model_runtime.errors.invoke import InvokeAuthorizationError +from openai._types import NOT_GIVEN, NotGiven +from openai.resources.moderations import Moderations +from openai.types import ModerationCreateResponse +from openai.types.moderation import Categories, CategoryScores, Moderation -import re class MockModerationClass(object): def moderation_create(self: Moderations,*, diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py index 5fc14d038b..3d665ad5c3 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py @@ -1,8 +1,9 @@ +from time import time +from typing import List + from openai.resources.models import Models from openai.types.model import Model -from typing import List -from time import time class MockModelClass(object): """ diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py index ae9692f363..8032747bd1 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py @@ -1,12 +1,11 @@ -from openai.resources.audio.transcriptions import Transcriptions -from openai._types import NotGiven, NOT_GIVEN, FileTypes -from openai.types.audio.transcription import Transcription - -from typing import Union, List, Literal, Any +import re +from typing import Any, List, Literal, Union from core.model_runtime.errors.invoke import InvokeAuthorizationError +from openai._types import NOT_GIVEN, FileTypes, NotGiven +from openai.resources.audio.transcriptions import Transcriptions +from openai.types.audio.transcription import Transcription -import re class MockSpeech2TextClass(object): def speech2text_create(self: Transcriptions, diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index d0eeeffd06..f5c61f4725 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -1,17 +1,17 @@ -from xinference_client.client.restful.restful_client import Client, \ - RESTfulChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatglmCppChatModelHandle, \ - RESTfulEmbeddingModelHandle, RESTfulRerankModelHandle -from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage - -from requests.sessions import Session -from requests import Response -from requests.exceptions import ConnectionError -from typing import Union, List - -from _pytest.monkeypatch import MonkeyPatch -import pytest import os import re +from typing import List, Union + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from requests import Response +from requests.exceptions import ConnectionError +from requests.sessions import Session +from xinference_client.client.restful.restful_client import (Client, RESTfulChatglmCppChatModelHandle, + RESTfulChatModelHandle, RESTfulEmbeddingModelHandle, + RESTfulGenerateModelHandle, RESTfulRerankModelHandle) +from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage + class MockXinferenceClass(object): def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]: diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py index 276d76bed4..ddba2a40ce 100644 --- a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py +++ b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py @@ -2,15 +2,13 @@ import os from typing import Generator import pytest - -from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \ - LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeLanguageModel - from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock + @pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) def test_validate_credentials(setup_anthropic_mock): model = AnthropicLargeLanguageModel() diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py index 16af242763..3ab624d351 100644 --- a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py +++ b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py @@ -1,12 +1,11 @@ import os import pytest - from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProvider - from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock + @pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) def test_validate_provider_credentials(setup_anthropic_mock): provider = AnthropicProvider() diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py index 90a81b1d97..bf9d9ea06b 100644 --- a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py @@ -2,16 +2,15 @@ import os from typing import Generator import pytest - -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \ - LLMResultChunk -from core.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent, \ - SystemPromptMessage, ImagePromptMessageContent, PromptMessageTool, UserPromptMessage +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, + PromptMessageTool, SystemPromptMessage, + TextPromptMessageContent, UserPromptMessage) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.azure_openai.llm.llm import AzureOpenAILargeLanguageModel - from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py index 797f699688..7dca6fedda 100644 --- a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py @@ -1,13 +1,12 @@ import os import pytest - from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.azure_openai.text_embedding.text_embedding import AzureOpenAITextEmbeddingModel - from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + @pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) def test_validate_credentials(setup_openai_mock): model = AzureOpenAITextEmbeddingModel() diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py index 4421b5008e..d4b1523f01 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py @@ -1,16 +1,15 @@ import os -import pytest - -from typing import Generator from time import sleep +from typing import Generator -from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, SystemPromptMessage +import pytest +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import AIModelEntity -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \ - LLMResultChunk from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.baichuan.llm.llm import BaichuanLarguageModel + def test_predefined_models(): model = BaichuanLarguageModel() model_schemas = model.predefined_models() diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py index 87b3d9a609..fc85a506ac 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py @@ -1,7 +1,6 @@ import os import pytest - from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.baichuan.baichuan import BaichuanProvider diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py index b0a6620bb0..932e48d808 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py @@ -1,11 +1,11 @@ import os import pytest - from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.baichuan.text_embedding.text_embedding import BaichuanTextEmbeddingModel + def test_validate_credentials(): model = BaichuanTextEmbeddingModel() diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py index 0b139c9ee2..d009dbefca 100644 --- a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py +++ b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py @@ -1,18 +1,17 @@ import os -import pytest - from typing import Generator -from core.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent, UserPromptMessage, \ - SystemPromptMessage, PromptMessageTool +import pytest +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, + SystemPromptMessage, TextPromptMessageContent, + UserPromptMessage) from core.model_runtime.entities.model_entities import AIModelEntity -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \ - LLMResultChunk from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.chatglm.llm.llm import ChatGLMLargeLanguageModel - from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + def test_predefined_models(): model = ChatGLMLargeLanguageModel() model_schemas = model.predefined_models() diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py index 3cfcf77403..4baa25a38b 100644 --- a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py +++ b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py @@ -1,12 +1,11 @@ import os import pytest - from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider - from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) def test_validate_provider_credentials(setup_openai_mock): provider = ChatGLMProvider() diff --git a/api/tests/integration_tests/model_runtime/cohere/test_provider.py b/api/tests/integration_tests/model_runtime/cohere/test_provider.py index a8f56b6194..176ba9bc07 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_provider.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_provider.py @@ -1,7 +1,6 @@ import os import pytest - from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.cohere.cohere import CohereProvider diff --git a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py index 34546c0348..a022193f8d 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py @@ -1,6 +1,6 @@ import os -import pytest +import pytest from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.cohere.rerank.rerank import CohereRerankModel diff --git a/api/tests/integration_tests/model_runtime/google/test_llm.py b/api/tests/integration_tests/model_runtime/google/test_llm.py index 907af09941..5383b2c05b 100644 --- a/api/tests/integration_tests/model_runtime/google/test_llm.py +++ b/api/tests/integration_tests/model_runtime/google/test_llm.py @@ -2,15 +2,15 @@ import os from typing import Generator import pytest - -from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage, TextPromptMessageContent, ImagePromptMessageContent -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \ - LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, + SystemPromptMessage, TextPromptMessageContent, + UserPromptMessage) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel - from tests.integration_tests.model_runtime.__mock.google import setup_google_mock + @pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) def test_validate_credentials(setup_google_mock): model = GoogleLargeLanguageModel() diff --git a/api/tests/integration_tests/model_runtime/google/test_provider.py b/api/tests/integration_tests/model_runtime/google/test_provider.py index 0478b6c409..5983ae8ba0 100644 --- a/api/tests/integration_tests/model_runtime/google/test_provider.py +++ b/api/tests/integration_tests/model_runtime/google/test_provider.py @@ -1,12 +1,11 @@ import os import pytest - from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.google.google import GoogleProvider - from tests.integration_tests.model_runtime.__mock.google import setup_google_mock + @pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) def test_validate_provider_credentials(setup_google_mock): provider = GoogleProvider() diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py index ec96acc174..08e56bc4fe 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py +++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py @@ -2,15 +2,13 @@ import os from typing import Generator import pytest - -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \ - LLMResultChunkDelta -from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.huggingface_hub.llm.llm import HuggingfaceHubLargeLanguageModel - from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock + @pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) def test_hosted_inference_api_validate_credentials(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py index e1774dd2f3..92ae289d0c 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py @@ -1,7 +1,6 @@ import os import pytest - from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.huggingface_hub.text_embedding.text_embedding import \ diff --git a/api/tests/integration_tests/model_runtime/jina/test_provider.py b/api/tests/integration_tests/model_runtime/jina/test_provider.py index 2b43248388..9568204b9d 100644 --- a/api/tests/integration_tests/model_runtime/jina/test_provider.py +++ b/api/tests/integration_tests/model_runtime/jina/test_provider.py @@ -1,7 +1,6 @@ import os import pytest - from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.jina.jina import JinaProvider diff --git a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py index ac17566174..d39970a23c 100644 --- a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py @@ -1,7 +1,6 @@ import os import pytest - from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.jina.text_embedding.text_embedding import JinaTextEmbeddingModel diff --git a/api/tests/integration_tests/model_runtime/localai/test_llm.py b/api/tests/integration_tests/model_runtime/localai/test_llm.py index 43d8eb633f..f885a67893 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/localai/test_llm.py @@ -1,16 +1,16 @@ import os -import pytest - from typing import Generator -from core.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent, UserPromptMessage, \ - SystemPromptMessage, PromptMessageTool +import pytest +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, + SystemPromptMessage, TextPromptMessageContent, + UserPromptMessage) from core.model_runtime.entities.model_entities import ParameterRule -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \ - LLMResultChunk from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.localai.llm.llm import LocalAILarguageModel + def test_validate_credentials_for_chat_model(): model = LocalAILarguageModel() diff --git a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py index fe1ad734d5..3a1e06ab22 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py @@ -1,11 +1,11 @@ import os import pytest - from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.minimax.text_embedding.text_embedding import MinimaxTextEmbeddingModel + def test_validate_credentials(): model = MinimaxTextEmbeddingModel() diff --git a/api/tests/integration_tests/model_runtime/minimax/test_llm.py b/api/tests/integration_tests/model_runtime/minimax/test_llm.py index a3ebbf7567..05f632a583 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_llm.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_llm.py @@ -1,16 +1,15 @@ import os -import pytest - -from typing import Generator from time import sleep +from typing import Generator +import pytest +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import AIModelEntity -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \ - LLMResultChunk from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.minimax.llm.llm import MinimaxLargeLanguageModel + def test_predefined_models(): model = MinimaxLargeLanguageModel() model_schemas = model.predefined_models() diff --git a/api/tests/integration_tests/model_runtime/minimax/test_provider.py b/api/tests/integration_tests/model_runtime/minimax/test_provider.py index 4c5462c6df..08872d704e 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_provider.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_provider.py @@ -1,7 +1,6 @@ import os import pytest - from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.minimax.minimax import MinimaxProvider diff --git a/api/tests/integration_tests/model_runtime/openai/test_llm.py b/api/tests/integration_tests/model_runtime/openai/test_llm.py index b379758e55..c7a9b48776 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openai/test_llm.py @@ -2,19 +2,19 @@ import os from typing import Generator import pytest - -from core.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent, UserPromptMessage, \ - SystemPromptMessage, ImagePromptMessageContent, PromptMessageTool +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, + PromptMessageTool, SystemPromptMessage, + TextPromptMessageContent, UserPromptMessage) from core.model_runtime.entities.model_entities import AIModelEntity, ModelType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \ - LLMResultChunk from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel """FOR MOCK FIXTURES, DO NOT REMOVE""" from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + def test_predefined_models(): model = OpenAILargeLanguageModel() model_schemas = model.predefined_models() diff --git a/api/tests/integration_tests/model_runtime/openai/test_moderation.py b/api/tests/integration_tests/model_runtime/openai/test_moderation.py index 1a1c943145..1154d76ad7 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_moderation.py +++ b/api/tests/integration_tests/model_runtime/openai/test_moderation.py @@ -1,12 +1,11 @@ import os import pytest - from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel - from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + @pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAIModerationModel() diff --git a/api/tests/integration_tests/model_runtime/openai/test_provider.py b/api/tests/integration_tests/model_runtime/openai/test_provider.py index d667364e5c..f4eaa61c04 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_provider.py +++ b/api/tests/integration_tests/model_runtime/openai/test_provider.py @@ -1,12 +1,11 @@ import os import pytest - from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.openai.openai import OpenAIProvider - from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) def test_validate_provider_credentials(setup_openai_mock): provider = OpenAIProvider() diff --git a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py index 6353743d6a..6d00ee2ea1 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py @@ -1,12 +1,11 @@ import os import pytest - from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.openai.speech2text.speech2text import OpenAISpeech2TextModel - from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + @pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAISpeech2TextModel() diff --git a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py index 4007222719..927903a5a0 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py @@ -1,13 +1,12 @@ import os import pytest - from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.openai.text_embedding.text_embedding import OpenAITextEmbeddingModel - from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + @pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAITextEmbeddingModel() diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py index 8be19b7c6c..9e94e562e9 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py @@ -2,11 +2,9 @@ import os from typing import Generator import pytest - -from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, \ - SystemPromptMessage, PromptMessageTool -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \ - LLMResultChunk +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, + SystemPromptMessage, UserPromptMessage) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py index 88a23c6f99..80be869ec1 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py @@ -1,10 +1,10 @@ import os import pytest - from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import OAICompatEmbeddingModel +from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import \ + OAICompatEmbeddingModel """ Using OpenAI's API as testing endpoint diff --git a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py index 341e0255bb..8b6fc6738d 100644 --- a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py @@ -1,11 +1,11 @@ import os import pytest - from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.openllm.text_embedding.text_embedding import OpenLLMTextEmbeddingModel + def test_validate_credentials(): model = OpenLLMTextEmbeddingModel() diff --git a/api/tests/integration_tests/model_runtime/openllm/test_llm.py b/api/tests/integration_tests/model_runtime/openllm/test_llm.py index 8536dd2073..42bd48cace 100644 --- a/api/tests/integration_tests/model_runtime/openllm/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openllm/test_llm.py @@ -1,14 +1,13 @@ import os -import pytest - from typing import Generator +import pytest +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \ - LLMResultChunk from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.openllm.llm.llm import OpenLLMLargeLanguageModel + def test_validate_credentials_for_chat_model(): model = OpenLLMLargeLanguageModel() diff --git a/api/tests/integration_tests/model_runtime/replicate/test_llm.py b/api/tests/integration_tests/model_runtime/replicate/test_llm.py index 61a4ab2807..f6768f20f8 100644 --- a/api/tests/integration_tests/model_runtime/replicate/test_llm.py +++ b/api/tests/integration_tests/model_runtime/replicate/test_llm.py @@ -2,10 +2,8 @@ import os from typing import Generator import pytest - -from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \ - LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.replicate.llm.llm import ReplicateLargeLanguageModel diff --git a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py index 5708ec9e5a..30144db74a 100644 --- a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py @@ -1,7 +1,6 @@ import os import pytest - from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.replicate.text_embedding.text_embedding import ReplicateEmbeddingModel diff --git a/api/tests/integration_tests/model_runtime/spark/test_llm.py b/api/tests/integration_tests/model_runtime/spark/test_llm.py index 2e3d775000..78ad71b4cf 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_llm.py +++ b/api/tests/integration_tests/model_runtime/spark/test_llm.py @@ -2,10 +2,8 @@ import os from typing import Generator import pytest - -from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \ - LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.spark.llm.llm import SparkLargeLanguageModel diff --git a/api/tests/integration_tests/model_runtime/spark/test_provider.py b/api/tests/integration_tests/model_runtime/spark/test_provider.py index 8e22815a86..8f65fa1af3 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_provider.py +++ b/api/tests/integration_tests/model_runtime/spark/test_provider.py @@ -1,7 +1,6 @@ import os import pytest - from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.spark.spark import SparkProvider diff --git a/api/tests/integration_tests/model_runtime/test_model_provider_factory.py b/api/tests/integration_tests/model_runtime/test_model_provider_factory.py index 7551baef91..fd8aa3f610 100644 --- a/api/tests/integration_tests/model_runtime/test_model_provider_factory.py +++ b/api/tests/integration_tests/model_runtime/test_model_provider_factory.py @@ -2,8 +2,8 @@ import logging import os from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import SimpleProviderEntity, ProviderConfig, ProviderEntity -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory, ModelProviderExtension +from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity +from core.model_runtime.model_providers.model_provider_factory import ModelProviderExtension, ModelProviderFactory logger = logging.getLogger(__name__) diff --git a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py index f4aad709c1..2581bd46c1 100644 --- a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py @@ -2,11 +2,9 @@ import os from typing import Generator import pytest - -from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, \ - SystemPromptMessage, PromptMessageTool -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \ - LLMResultChunk +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, + SystemPromptMessage, UserPromptMessage) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.togetherai.llm.llm import TogetherAILargeLanguageModel diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py index 65e57f7001..217a17d801 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py @@ -2,10 +2,8 @@ import os from typing import Generator import pytest - -from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \ - LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.tongyi.llm.llm import TongyiLargeLanguageModel diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py index 6145c1dc37..4cfe5930f4 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py @@ -1,7 +1,6 @@ import os import pytest - from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.tongyi.tongyi import TongyiProvider diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py index a636f1f064..1af21f147e 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py @@ -1,16 +1,15 @@ import os -import pytest - -from typing import Generator from time import sleep +from typing import Generator -from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, SystemPromptMessage +import pytest +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import AIModelEntity -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \ - LLMResultChunk from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLarguageModel + def test_predefined_models(): model = ErnieBotLarguageModel() model_schemas = model.predefined_models() diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py index 8922aa1868..683135b534 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py @@ -1,7 +1,6 @@ import os import pytest - from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.wenxin.wenxin import WenxinProvider diff --git a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py index f0ee893f75..c3f2f7083c 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py @@ -1,12 +1,11 @@ import os import pytest - from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.xinference.text_embedding.text_embedding import XinferenceTextEmbeddingModel +from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock -from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock, MOCK @pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) def test_validate_credentials(setup_xinference_mock): diff --git a/api/tests/integration_tests/model_runtime/xinference/test_llm.py b/api/tests/integration_tests/model_runtime/xinference/test_llm.py index 2974e86466..f31e6e48f5 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_llm.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_llm.py @@ -1,13 +1,12 @@ import os -import pytest - from typing import Generator -from core.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent, UserPromptMessage, \ - SystemPromptMessage, PromptMessageTool +import pytest +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, + SystemPromptMessage, TextPromptMessageContent, + UserPromptMessage) from core.model_runtime.entities.model_entities import AIModelEntity -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \ - LLMResultChunk from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.xinference.llm.llm import XinferenceAILargeLanguageModel @@ -15,6 +14,7 @@ from core.model_runtime.model_providers.xinference.llm.llm import XinferenceAILa from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock + @pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() diff --git a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py index b1197aa6ae..dd638317bd 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py @@ -1,11 +1,11 @@ import os -import pytest +import pytest from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.xinference.rerank.rerank import XinferenceRerankModel +from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock -from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock, MOCK @pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) def test_validate_credentials(setup_xinference_mock): diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py index adcaa51b35..e5a3d0ad1c 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py @@ -2,10 +2,8 @@ import os from typing import Generator import pytest - -from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \ - LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py index 032e15e846..6ec65df7e3 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py @@ -1,10 +1,10 @@ import os import pytest - from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.zhipuai.zhipuai import ZhipuaiProvider + def test_validate_provider_credentials(): provider = ZhipuaiProvider() diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py index 15a9307a32..30453eafb1 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py @@ -1,7 +1,6 @@ import os import pytest - from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.zhipuai.text_embedding.text_embedding import ZhipuAITextEmbeddingModel diff --git a/dev/reformat b/dev/reformat new file mode 100755 index 0000000000..8e0baf5e11 --- /dev/null +++ b/dev/reformat @@ -0,0 +1,11 @@ +#!/bin/bash + +set -x + +# python style checks rely on `isort` in path +if ! command -v isort &> /dev/null +then + echo "Skip Python imports linting, since 'isort' is not available. Please install it with 'pip install isort'." +else + isort --settings ./.github/linters/.isort.cfg ./ +fi diff --git a/sdks/python-client/tests/test_client.py b/sdks/python-client/tests/test_client.py index ac954ff831..5259d082ca 100644 --- a/sdks/python-client/tests/test_client.py +++ b/sdks/python-client/tests/test_client.py @@ -1,5 +1,6 @@ import os import unittest + from dify_client.client import ChatClient, CompletionClient, DifyClient API_KEY = os.environ.get("API_KEY")