mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Garfield Dai <dai.hai@foxmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
parent
e91dd28a76
commit
d069c668f8
58
.github/workflows/api-model-runtime-tests.yml
vendored
Normal file
58
.github/workflows/api-model-runtime-tests.yml
vendored
Normal file
|
@ -0,0 +1,58 @@
|
|||
name: Run Pytest
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- deploy/dev
|
||||
- feat/model-runtime
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
|
||||
AZURE_OPENAI_API_BASE: https://difyai-openai.openai.azure.com
|
||||
AZURE_OPENAI_API_KEY: xxxxb1707exxxxxxxxxxaaxxxxxf94
|
||||
ANTHROPIC_API_KEY: sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
|
||||
CHATGLM_API_BASE: http://a.abc.com:11451
|
||||
XINFERENCE_SERVER_URL: http://a.abc.com:11451
|
||||
XINFERENCE_GENERATION_MODEL_UID: generate
|
||||
XINFERENCE_CHAT_MODEL_UID: chat
|
||||
XINFERENCE_EMBEDDINGS_MODEL_UID: embedding
|
||||
XINFERENCE_RERANK_MODEL_UID: rerank
|
||||
GOOGLE_API_KEY: abcdefghijklmnopqrstuvwxyz
|
||||
HUGGINGFACE_API_KEY: hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
|
||||
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL: a
|
||||
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL: b
|
||||
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c
|
||||
MOCK_SWITCH: true
|
||||
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Cache pip dependencies
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
|
||||
restore-keys: ${{ runner.os }}-pip-
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest
|
||||
pip install -r api/requirements.txt
|
||||
|
||||
- name: Run pytest
|
||||
run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py
|
38
.github/workflows/api-unit-tests.yml
vendored
38
.github/workflows/api-unit-tests.yml
vendored
|
@ -1,38 +0,0 @@
|
|||
name: Run Pytest
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- deploy/dev
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Cache pip dependencies
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
|
||||
restore-keys: ${{ runner.os }}-pip-
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest
|
||||
pip install -r api/requirements.txt
|
||||
|
||||
- name: Run pytest
|
||||
run: pytest api/tests/unit_tests
|
|
@ -55,6 +55,11 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r
|
|||
|
||||
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!
|
||||
|
||||
|
||||
### Provider Integrations
|
||||
If you see a model provider not yet supported by Dify that you'd like to use, follow these [steps](api/core/model_runtime/README.md) to submit a PR.
|
||||
|
||||
|
||||
### i18n (Internationalization) Support
|
||||
|
||||
We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.
|
||||
|
|
15
api/.vscode/launch.json
vendored
15
api/.vscode/launch.json
vendored
|
@ -4,6 +4,21 @@
|
|||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: Celery",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"justMyCode": true,
|
||||
"args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"],
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"FLASK_APP": "app.py",
|
||||
"FLASK_DEBUG": "1",
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"console": "integratedTerminal"
|
||||
},
|
||||
{
|
||||
"name": "Python: Flask",
|
||||
"type": "python",
|
||||
|
|
|
@ -34,9 +34,6 @@ RUN apt-get update \
|
|||
COPY --from=base /pkg /usr/local
|
||||
COPY . /app/api/
|
||||
|
||||
RUN python -c "from transformers import GPT2TokenizerFast; GPT2TokenizerFast.from_pretrained('gpt2')"
|
||||
ENV TRANSFORMERS_OFFLINE true
|
||||
|
||||
COPY docker/entrypoint.sh /entrypoint.sh
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
|
|
16
api/app.py
16
api/app.py
|
@ -6,10 +6,13 @@ from werkzeug.exceptions import Unauthorized
|
|||
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
|
||||
from gevent import monkey
|
||||
monkey.patch_all()
|
||||
if os.environ.get("VECTOR_STORE") == 'milvus':
|
||||
# if os.environ.get("VECTOR_STORE") == 'milvus':
|
||||
import grpc.experimental.gevent
|
||||
grpc.experimental.gevent.init_gevent()
|
||||
|
||||
import langchain
|
||||
langchain.verbose = True
|
||||
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
|
@ -18,9 +21,8 @@ import threading
|
|||
from flask import Flask, request, Response
|
||||
from flask_cors import CORS
|
||||
|
||||
from core.model_providers.providers import hosted
|
||||
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
||||
ext_database, ext_storage, ext_mail, ext_code_based_extension
|
||||
ext_database, ext_storage, ext_mail, ext_code_based_extension, ext_hosting_provider
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
|
||||
|
@ -79,8 +81,6 @@ def create_app(test_config=None) -> Flask:
|
|||
register_blueprints(app)
|
||||
register_commands(app)
|
||||
|
||||
hosted.init_app(app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
@ -95,6 +95,7 @@ def initialize_extensions(app):
|
|||
ext_celery.init_app(app)
|
||||
ext_login.init_app(app)
|
||||
ext_mail.init_app(app)
|
||||
ext_hosting_provider.init_app(app)
|
||||
ext_sentry.init_app(app)
|
||||
|
||||
|
||||
|
@ -105,6 +106,11 @@ def load_user_from_request(request_from_flask_login):
|
|||
if request.blueprint == 'console':
|
||||
# Check if the user_id contains a dot, indicating the old format
|
||||
auth_header = request.headers.get('Authorization', '')
|
||||
if not auth_header:
|
||||
auth_token = request.args.get('_token')
|
||||
if not auth_token:
|
||||
raise Unauthorized('Invalid Authorization token.')
|
||||
else:
|
||||
if ' ' not in auth_header:
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
|
|
|
@ -12,16 +12,12 @@ import qdrant_client
|
|||
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
|
||||
from tqdm import tqdm
|
||||
from flask import current_app, Flask
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.index import IndexBuilder
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.hosted import hosted_model_providers
|
||||
from core.model_providers.providers.openai_provider import OpenAIProvider
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from libs.password import password_pattern, valid_password, hash_password
|
||||
from libs.helper import email as email_validate
|
||||
from extensions.ext_database import db
|
||||
|
@ -327,6 +323,8 @@ def create_qdrant_indexes():
|
|||
except NotFound:
|
||||
break
|
||||
|
||||
model_manager = ModelManager()
|
||||
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
if dataset.index_struct_dict:
|
||||
|
@ -334,19 +332,23 @@ def create_qdrant_indexes():
|
|||
try:
|
||||
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
embedding_model = model_manager.get_default_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
dataset.embedding_model = embedding_model.name
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
|
||||
dataset.embedding_model = embedding_model.model
|
||||
dataset.embedding_model_provider = embedding_model.provider
|
||||
except Exception:
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id=dataset.tenant_id,
|
||||
|
|
|
@ -87,7 +87,7 @@ class Config:
|
|||
# ------------------------
|
||||
# General Configurations.
|
||||
# ------------------------
|
||||
self.CURRENT_VERSION = "0.3.34"
|
||||
self.CURRENT_VERSION = "0.4.0"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
|
|
|
@ -18,7 +18,7 @@ from .auth import login, oauth, data_source_oauth, activate
|
|||
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
|
||||
|
||||
# Import workspace controllers
|
||||
from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
|
||||
from .workspace import workspace, members, model_providers, account, tool_providers, models
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
|
||||
|
|
|
@ -4,6 +4,10 @@ import logging
|
|||
from datetime import datetime
|
||||
|
||||
from flask_login import current_user
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal_with, abort, inputs
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
@ -13,9 +17,7 @@ from controllers.console import api
|
|||
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from events.app_event import app_was_created, app_was_deleted
|
||||
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
|
||||
app_detail_fields_with_site
|
||||
|
@ -73,39 +75,41 @@ class AppListApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
try:
|
||||
default_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
provider_manager = ProviderManager()
|
||||
default_model_entity = provider_manager.get_default_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
except (ProviderTokenNotInitError, LLMBadRequestError):
|
||||
default_model = None
|
||||
default_model_entity = None
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
default_model = None
|
||||
default_model_entity = None
|
||||
|
||||
if args['model_config'] is not None:
|
||||
# validate config
|
||||
model_config_dict = args['model_config']
|
||||
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(
|
||||
current_user.current_tenant_id,
|
||||
model_config_dict["model"]["provider"]
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if not model_provider:
|
||||
if not default_model:
|
||||
if not model_instance:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Default System Reasoning Model available. Please configure "
|
||||
f"in the Settings -> Model Provider.")
|
||||
else:
|
||||
model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
|
||||
model_config_dict["model"]["name"] = default_model.name
|
||||
model_config_dict["model"]["provider"] = model_instance.provider
|
||||
model_config_dict["model"]["name"] = model_instance.model
|
||||
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account=current_user,
|
||||
config=model_config_dict,
|
||||
mode=args['mode']
|
||||
app_mode=args['mode']
|
||||
)
|
||||
|
||||
app = App(
|
||||
|
@ -129,20 +133,26 @@ class AppListApi(Resource):
|
|||
app_model_config = AppModelConfig(**model_config_template['model_config'])
|
||||
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(
|
||||
current_user.current_tenant_id,
|
||||
app_model_config.model_dict["provider"]
|
||||
)
|
||||
model_manager = ModelManager()
|
||||
|
||||
if not model_provider:
|
||||
if not default_model:
|
||||
try:
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Default System Reasoning Model available. Please configure "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
if not model_instance:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Default System Reasoning Model available. Please configure "
|
||||
f"in the Settings -> Model Provider.")
|
||||
else:
|
||||
model_dict = app_model_config.model_dict
|
||||
model_dict['provider'] = default_model.model_provider.provider_name
|
||||
model_dict['name'] = default_model.name
|
||||
model_dict['provider'] = model_instance.provider
|
||||
model_dict['name'] = model_instance.model
|
||||
app_model_config.model = json.dumps(model_dict)
|
||||
|
||||
app.name = args['name']
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
|
@ -14,8 +16,7 @@ from controllers.console.app.error import AppUnavailableError, \
|
|||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from flask_restful import Resource
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
|
@ -56,8 +57,7 @@ class ChatMessageAudioApi(Resource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
|
|
@ -5,6 +5,10 @@ from typing import Generator, Union
|
|||
|
||||
import flask_login
|
||||
from flask import Response, stream_with_context
|
||||
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
|
@ -16,9 +20,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
|
|||
ProviderModelCurrentlyNotSupportError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.helper import uuid_value
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
|
@ -56,7 +58,7 @@ class CompletionMessageApi(Resource):
|
|||
app_model=app_model,
|
||||
user=account,
|
||||
args=args,
|
||||
from_source='console',
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=streaming,
|
||||
is_model_config_override=True
|
||||
)
|
||||
|
@ -75,8 +77,7 @@ class CompletionMessageApi(Resource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
@ -97,7 +98,7 @@ class CompletionMessageStopApi(Resource):
|
|||
|
||||
account = flask_login.current_user
|
||||
|
||||
PubHandler.stop(account, task_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
@ -132,7 +133,7 @@ class ChatMessageApi(Resource):
|
|||
app_model=app_model,
|
||||
user=account,
|
||||
args=args,
|
||||
from_source='console',
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=streaming,
|
||||
is_model_config_override=True
|
||||
)
|
||||
|
@ -151,8 +152,7 @@ class ChatMessageApi(Resource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
@ -182,9 +182,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
|||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
|
||||
except InvokeError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
|
||||
except ValueError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
|
||||
except Exception:
|
||||
|
@ -207,7 +206,7 @@ class ChatMessageStopApi(Resource):
|
|||
|
||||
account = flask_login.current_user
|
||||
|
||||
PubHandler.stop(account, task_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
from flask_login import current_user
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
|
@ -8,8 +10,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
|
|||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
|
||||
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
|
||||
|
||||
class RuleGenerateApi(Resource):
|
||||
|
@ -36,8 +37,7 @@ class RuleGenerateApi(Resource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
|
||||
return rules
|
||||
|
|
|
@ -14,8 +14,9 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni
|
|||
AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
from fields.conversation_fields import message_detail_fields, annotation_fields
|
||||
from libs.helper import uuid_value
|
||||
|
@ -208,7 +209,13 @@ class MessageMoreLikeThisApi(Resource):
|
|||
app_model = _get_app(app_id, 'completion')
|
||||
|
||||
try:
|
||||
response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
|
||||
response = CompletionService.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
message_id=message_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=streaming
|
||||
)
|
||||
return compact_response(response)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
@ -220,8 +227,7 @@ class MessageMoreLikeThisApi(Resource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
@ -249,8 +255,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
|||
except ModelCurrentlyNotSupportError:
|
||||
yield "data: " + json.dumps(
|
||||
api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
|
||||
except ValueError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
|
||||
|
@ -290,8 +295,7 @@ class MessageSuggestedQuestionApi(Resource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
|
|
|
@ -31,7 +31,7 @@ class ModelConfigResource(Resource):
|
|||
tenant_id=current_user.current_tenant_id,
|
||||
account=current_user,
|
||||
config=request.json,
|
||||
mode=app.mode
|
||||
app_mode=app.mode
|
||||
)
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
|
|
|
@ -4,6 +4,8 @@ from flask import request, current_app
|
|||
from flask_login import current_user
|
||||
|
||||
from controllers.console.apikey import api_key_list, api_key_fields
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal, marshal_with
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
@ -14,8 +16,7 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
|
|||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from fields.app_fields import related_app_list
|
||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
from fields.document_fields import document_status_fields
|
||||
|
@ -23,7 +24,6 @@ from extensions.ext_database import db
|
|||
from models.dataset import DocumentSegment, Document
|
||||
from models.model import UploadFile, ApiToken
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
|
@ -55,16 +55,20 @@ class DatasetListApi(Resource):
|
|||
current_user.current_tenant_id, current_user)
|
||||
|
||||
# check embedding setting
|
||||
provider_service = ProviderService()
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
|
||||
ModelType.EMBEDDINGS.value)
|
||||
# if len(valid_model_list) == 0:
|
||||
# raise ProviderNotInitializeError(
|
||||
# f"No Embedding Model available. Please configure a valid provider "
|
||||
# f"in the Settings -> Model Provider.")
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
|
||||
embedding_models = configurations.get_models(
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
only_active=True
|
||||
)
|
||||
|
||||
model_names = []
|
||||
for valid_model in valid_model_list:
|
||||
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
if item['indexing_technique'] == 'high_quality':
|
||||
|
@ -75,6 +79,7 @@ class DatasetListApi(Resource):
|
|||
item['embedding_available'] = False
|
||||
else:
|
||||
item['embedding_available'] = True
|
||||
|
||||
response = {
|
||||
'data': data,
|
||||
'has_more': len(datasets) == limit,
|
||||
|
@ -130,13 +135,20 @@ class DatasetApi(Resource):
|
|||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
# check embedding setting
|
||||
provider_service = ProviderService()
|
||||
# get valid model list
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
|
||||
ModelType.EMBEDDINGS.value)
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
|
||||
embedding_models = configurations.get_models(
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
only_active=True
|
||||
)
|
||||
|
||||
model_names = []
|
||||
for valid_model in valid_model_list:
|
||||
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
if data['indexing_technique'] == 'high_quality':
|
||||
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
|
|
|
@ -2,8 +2,12 @@
|
|||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from flask import request, current_app
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import desc, asc
|
||||
|
@ -18,9 +22,8 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
|
|||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.document_fields import document_with_segments_fields, document_fields, \
|
||||
dataset_and_document_fields, document_status_fields
|
||||
|
@ -272,10 +275,12 @@ class DatasetInitApi(Resource):
|
|||
args = parser.parse_args()
|
||||
if args['indexing_technique'] == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_default_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
except InvokeAuthorizationError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
|
|
@ -12,8 +12,9 @@ from controllers.console.app.error import ProviderNotInitializeError
|
|||
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from libs.login import login_required
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
|
@ -133,10 +134,12 @@ class DatasetDocumentSegmentApi(Resource):
|
|||
if dataset.indexing_technique == 'high_quality':
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
|
@ -219,10 +222,12 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||
# check embedding model setting
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
|
@ -269,10 +274,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
if dataset.indexing_technique == 'high_quality':
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
|
|
|
@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
|
|||
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
LLMBadRequestError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from services.dataset_service import DatasetService
|
||||
|
|
|
@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
|
|||
NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
|
@ -53,8 +53,7 @@ class ChatAudioApi(InstalledAppResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
|
|
@ -15,9 +15,10 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
|
|||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
|
||||
from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
|
@ -50,7 +51,7 @@ class CompletionApi(InstalledAppResource):
|
|||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=args,
|
||||
from_source='console',
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
|
@ -68,8 +69,7 @@ class CompletionApi(InstalledAppResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
@ -84,7 +84,7 @@ class CompletionStopApi(InstalledAppResource):
|
|||
if app_model.mode != 'completion':
|
||||
raise NotCompletionAppError()
|
||||
|
||||
PubHandler.stop(current_user, task_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
@ -115,7 +115,7 @@ class ChatApi(InstalledAppResource):
|
|||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=args,
|
||||
from_source='console',
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
|
@ -133,8 +133,7 @@ class ChatApi(InstalledAppResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
@ -149,7 +148,7 @@ class ChatStopApi(InstalledAppResource):
|
|||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
PubHandler.stop(current_user, task_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
@ -175,8 +174,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
|||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
|
||||
except ValueError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Generator, Union
|
|||
|
||||
from flask import stream_with_context, Response
|
||||
from flask_login import current_user
|
||||
from flask_restful import reqparse, fields, marshal_with
|
||||
from flask_restful import reqparse, marshal_with
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound, InternalServerError
|
||||
|
||||
|
@ -13,12 +13,14 @@ import services
|
|||
from controllers.console import api
|
||||
from controllers.console.app.error import AppMoreLikeThisDisabledError, ProviderNotInitializeError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
|
||||
from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
|
||||
NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
@ -83,7 +85,13 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
|||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
try:
|
||||
response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
|
||||
response = CompletionService.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
message_id=message_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
streaming=streaming
|
||||
)
|
||||
return compact_response(response)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
@ -95,8 +103,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
@ -123,8 +130,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
|||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
|
||||
except ValueError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
|
||||
|
@ -162,8 +168,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
|
|
|
@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
|
|||
NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
|
@ -53,8 +53,7 @@ class UniversalChatAudioApi(UniversalChatResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
|
|
@ -12,9 +12,10 @@ from controllers.console import api
|
|||
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
|
||||
|
@ -68,7 +69,7 @@ class UniversalChatApi(UniversalChatResource):
|
|||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=args,
|
||||
from_source='console',
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
streaming=True,
|
||||
is_model_config_override=True,
|
||||
)
|
||||
|
@ -87,8 +88,7 @@ class UniversalChatApi(UniversalChatResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
@ -99,7 +99,7 @@ class UniversalChatApi(UniversalChatResource):
|
|||
|
||||
class UniversalChatStopApi(UniversalChatResource):
|
||||
def post(self, universal_app, task_id):
|
||||
PubHandler.stop(current_user, task_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
@ -125,8 +125,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
|||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
|
||||
except ValueError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
|
||||
|
|
|
@ -12,8 +12,8 @@ from controllers.console.app.error import ProviderNotInitializeError, \
|
|||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
|
@ -132,8 +132,7 @@ class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
|
|
|
@ -1,16 +1,19 @@
|
|||
import io
|
||||
|
||||
from flask import send_file
|
||||
from flask_login import current_user
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from services.provider_service import ProviderService
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
class ModelProviderListApi(Resource):
|
||||
|
@ -22,13 +25,36 @@ class ModelProviderListApi(Resource):
|
|||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_type', type=str, required=False, nullable=True, location='args')
|
||||
parser.add_argument('model_type', type=str, required=False, nullable=True,
|
||||
choices=[mt.value for mt in ModelType], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_list = provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get('model_type'))
|
||||
model_provider_service = ModelProviderService()
|
||||
provider_list = model_provider_service.get_provider_list(
|
||||
tenant_id=tenant_id,
|
||||
model_type=args.get('model_type')
|
||||
)
|
||||
|
||||
return provider_list
|
||||
return jsonable_encoder({"data": provider_list})
|
||||
|
||||
|
||||
class ModelProviderCredentialApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
credentials = model_provider_service.get_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider
|
||||
)
|
||||
|
||||
return {
|
||||
"credentials": credentials
|
||||
}
|
||||
|
||||
|
||||
class ModelProviderValidateApi(Resource):
|
||||
|
@ -36,21 +62,24 @@ class ModelProviderValidateApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
def post(self, provider: str):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
result = True
|
||||
error = None
|
||||
|
||||
try:
|
||||
provider_service.custom_provider_config_validate(
|
||||
provider_name=provider_name,
|
||||
config=args['config']
|
||||
model_provider_service.provider_credentials_validate(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
credentials=args['credentials']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
|
@ -64,26 +93,26 @@ class ModelProviderValidateApi(Resource):
|
|||
return response
|
||||
|
||||
|
||||
class ModelProviderUpdateApi(Resource):
|
||||
class ModelProviderApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
def post(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
provider_service.save_custom_provider_config(
|
||||
model_provider_service.save_provider_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
config=args['config']
|
||||
provider=provider,
|
||||
credentials=args['credentials']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
@ -93,109 +122,36 @@ class ModelProviderUpdateApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_name: str):
|
||||
def delete(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_service.delete_custom_provider(
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name
|
||||
provider=provider
|
||||
)
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class ModelProviderModelValidateApi(Resource):
|
||||
class ModelProviderIconApi(Resource):
|
||||
"""
|
||||
Get model provider icon
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
result = True
|
||||
error = None
|
||||
|
||||
try:
|
||||
provider_service.custom_provider_model_config_validate(
|
||||
provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type=args['model_type'],
|
||||
config=args['config']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
error = str(ex)
|
||||
|
||||
response = {'result': 'success' if result else 'error'}
|
||||
|
||||
if not result:
|
||||
response['error'] = error
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class ModelProviderModelUpdateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
try:
|
||||
provider_service.add_or_save_custom_provider_model_config(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type=args['model_type'],
|
||||
config=args['config']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_name: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_service.delete_custom_provider_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type=args['model_type']
|
||||
def get(self, provider: str, icon_type: str, lang: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
icon, mimetype = model_provider_service.get_model_provider_icon(
|
||||
provider=provider,
|
||||
icon_type=icon_type,
|
||||
lang=lang
|
||||
)
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
return send_file(io.BytesIO(icon), mimetype=mimetype)
|
||||
|
||||
|
||||
class PreferredProviderTypeUpdateApi(Resource):
|
||||
|
@ -203,71 +159,36 @@ class PreferredProviderTypeUpdateApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
def post(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
|
||||
choices=['system', 'custom'], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_service.switch_preferred_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.switch_preferred_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
preferred_provider_type=args['preferred_provider_type']
|
||||
)
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class ModelProviderModelParameterRuleApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
try:
|
||||
parameter_rules = provider_service.get_model_parameter_rules(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type='text-generation'
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"Current Text Generation Model is invalid. Please switch to the available model.")
|
||||
|
||||
rules = {
|
||||
k: {
|
||||
'enabled': v.enabled,
|
||||
'min': v.min,
|
||||
'max': v.max,
|
||||
'default': v.default,
|
||||
'precision': v.precision
|
||||
}
|
||||
for k, v in vars(parameter_rules).items()
|
||||
}
|
||||
|
||||
return rules
|
||||
|
||||
|
||||
class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
if provider_name != 'anthropic':
|
||||
raise ValueError(f'provider name {provider_name} is invalid')
|
||||
def get(self, provider: str):
|
||||
if provider != 'anthropic':
|
||||
raise ValueError(f'provider name {provider} is invalid')
|
||||
|
||||
data = BillingService.get_model_provider_payment_link(provider_name=provider_name,
|
||||
data = BillingService.get_model_provider_payment_link(provider_name=provider,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account_id=current_user.id)
|
||||
return data
|
||||
|
@ -277,11 +198,11 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
provider_service = ProviderService()
|
||||
result = provider_service.free_quota_submit(
|
||||
def post(self, provider: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
result = model_provider_service.free_quota_submit(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name
|
||||
provider=provider
|
||||
)
|
||||
|
||||
return result
|
||||
|
@ -291,15 +212,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
def get(self, provider: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', type=str, required=False, nullable=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
result = provider_service.free_quota_qualification_verify(
|
||||
model_provider_service = ModelProviderService()
|
||||
result = model_provider_service.free_quota_qualification_verify(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
provider=provider,
|
||||
token=args['token']
|
||||
)
|
||||
|
||||
|
@ -307,19 +228,18 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
|
|||
|
||||
|
||||
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
|
||||
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
|
||||
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
|
||||
api.add_resource(ModelProviderModelValidateApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/models/validate')
|
||||
api.add_resource(ModelProviderModelUpdateApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/models')
|
||||
|
||||
api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials')
|
||||
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate')
|
||||
api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/<string:provider>')
|
||||
api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/'
|
||||
'<string:icon_type>/<string:lang>')
|
||||
|
||||
api.add_resource(PreferredProviderTypeUpdateApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type')
|
||||
api.add_resource(ModelProviderModelParameterRuleApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
|
||||
'/workspaces/current/model-providers/<string:provider>/preferred-provider-type')
|
||||
api.add_resource(ModelProviderPaymentCheckoutUrlApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/checkout-url')
|
||||
'/workspaces/current/model-providers/<string:provider>/checkout-url')
|
||||
api.add_resource(ModelProviderFreeQuotaSubmitApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
|
||||
'/workspaces/current/model-providers/<string:provider>/free-quota-submit')
|
||||
api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify')
|
||||
'/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify')
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from flask_restful import reqparse, Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from models.provider import ProviderType
|
||||
from services.provider_service import ProviderService
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
class DefaultModelApi(Resource):
|
||||
|
@ -21,52 +22,20 @@ class DefaultModelApi(Resource):
|
|||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
|
||||
choices=[mt.value for mt in ModelType], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
provider_service = ProviderService()
|
||||
default_model = provider_service.get_default_model_of_model_type(
|
||||
model_provider_service = ModelProviderService()
|
||||
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
||||
tenant_id=tenant_id,
|
||||
model_type=args['model_type']
|
||||
)
|
||||
|
||||
if not default_model:
|
||||
return None
|
||||
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(
|
||||
tenant_id,
|
||||
default_model.provider_name
|
||||
)
|
||||
|
||||
if not model_provider:
|
||||
return {
|
||||
'model_name': default_model.model_name,
|
||||
'model_type': default_model.model_type,
|
||||
'model_provider': {
|
||||
'provider_name': default_model.provider_name
|
||||
}
|
||||
}
|
||||
|
||||
provider = model_provider.provider
|
||||
rst = {
|
||||
'model_name': default_model.model_name,
|
||||
'model_type': default_model.model_type,
|
||||
'model_provider': {
|
||||
'provider_name': provider.provider_name,
|
||||
'provider_type': provider.provider_type
|
||||
}
|
||||
}
|
||||
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
rst['model_provider']['quota_type'] = provider.quota_type
|
||||
rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
|
||||
rst['model_provider']['quota_limit'] = provider.quota_limit
|
||||
rst['model_provider']['quota_used'] = provider.quota_used
|
||||
|
||||
return rst
|
||||
return jsonable_encoder({
|
||||
"data": default_model_entity
|
||||
})
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
@ -76,15 +45,26 @@ class DefaultModelApi(Resource):
|
|||
parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_settings = args['model_settings']
|
||||
for model_setting in model_settings:
|
||||
if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]:
|
||||
raise ValueError('invalid model type')
|
||||
|
||||
if 'provider' not in model_setting:
|
||||
continue
|
||||
|
||||
if 'model' not in model_setting:
|
||||
raise ValueError('invalid model')
|
||||
|
||||
try:
|
||||
provider_service.update_default_model_of_model_type(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_service.update_default_model_of_model_type(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_setting['model_type'],
|
||||
provider_name=model_setting['provider_name'],
|
||||
model_name=model_setting['model_name']
|
||||
provider=model_setting['provider'],
|
||||
model=model_setting['model']
|
||||
)
|
||||
except Exception:
|
||||
logging.warning(f"{model_setting['model_type']} save error")
|
||||
|
@ -92,22 +72,198 @@ class DefaultModelApi(Resource):
|
|||
return {'result': 'success'}
|
||||
|
||||
|
||||
class ValidModelApi(Resource):
|
||||
class ModelProviderModelApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider
|
||||
)
|
||||
|
||||
return jsonable_encoder({
|
||||
"data": models
|
||||
})
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=[mt.value for mt in ModelType], location='json')
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.save_model_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args['model'],
|
||||
model_type=args['model_type'],
|
||||
credentials=args['credentials']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=[mt.value for mt in ModelType], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args['model'],
|
||||
model_type=args['model_type']
|
||||
)
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class ModelProviderModelCredentialApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model', type=str, required=True, nullable=False, location='args')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=[mt.value for mt in ModelType], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
credentials = model_provider_service.get_model_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=args['model_type'],
|
||||
model=args['model']
|
||||
)
|
||||
|
||||
return {
|
||||
"credentials": credentials
|
||||
}
|
||||
|
||||
|
||||
class ModelProviderModelValidateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=[mt.value for mt in ModelType], location='json')
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
result = True
|
||||
error = None
|
||||
|
||||
try:
|
||||
model_provider_service.model_credentials_validate(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args['model'],
|
||||
model_type=args['model_type'],
|
||||
credentials=args['credentials']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
error = str(ex)
|
||||
|
||||
response = {'result': 'success' if result else 'error'}
|
||||
|
||||
if not result:
|
||||
response['error'] = error
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class ModelProviderModelParameterRuleApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model', type=str, required=True, nullable=False, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
parameter_rules = model_provider_service.get_model_parameter_rules(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args['model']
|
||||
)
|
||||
|
||||
return jsonable_encoder({
|
||||
"data": parameter_rules
|
||||
})
|
||||
|
||||
|
||||
class ModelProviderAvailableModelApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, model_type):
|
||||
ModelType.value_of(model_type)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
provider_service = ProviderService()
|
||||
valid_models = provider_service.get_valid_model_list(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type
|
||||
)
|
||||
|
||||
return valid_models
|
||||
return jsonable_encoder({
|
||||
"data": models
|
||||
})
|
||||
|
||||
|
||||
api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
|
||||
api.add_resource(ModelProviderModelCredentialApi,
|
||||
'/workspaces/current/model-providers/<string:provider>/models/credentials')
|
||||
api.add_resource(ModelProviderModelValidateApi,
|
||||
'/workspaces/current/model-providers/<string:provider>/models/credentials/validate')
|
||||
|
||||
api.add_resource(ModelProviderModelParameterRuleApi,
|
||||
'/workspaces/current/model-providers/<string:provider>/models/parameter-rules')
|
||||
api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/<string:model_type>')
|
||||
api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
|
||||
api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')
|
||||
|
|
|
@ -1,131 +0,0 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from models.provider import ProviderType
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
class ProviderListApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
"""
|
||||
If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
|
||||
azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
|
||||
rest is replaced by * and the last two bits are displayed in plaintext
|
||||
|
||||
If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
|
||||
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
|
||||
"""
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_info_list = provider_service.get_provider_list(tenant_id)
|
||||
|
||||
provider_list = [
|
||||
{
|
||||
'provider_name': p['provider_name'],
|
||||
'provider_type': p['provider_type'],
|
||||
'is_valid': p['is_valid'],
|
||||
'last_used': p['last_used'],
|
||||
'is_enabled': p['is_valid'],
|
||||
**({
|
||||
'quota_type': p['quota_type'],
|
||||
'quota_limit': p['quota_limit'],
|
||||
'quota_used': p['quota_used']
|
||||
} if p['provider_type'] == ProviderType.SYSTEM.value else {}),
|
||||
'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
|
||||
if p['config'] else None
|
||||
}
|
||||
for name, provider_info in provider_info_list.items()
|
||||
for p in provider_info['providers']
|
||||
]
|
||||
|
||||
return provider_list
|
||||
|
||||
|
||||
class ProviderTokenApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if provider == 'openai':
|
||||
args['token'] = {
|
||||
'openai_api_key': args['token']
|
||||
}
|
||||
|
||||
provider_service = ProviderService()
|
||||
try:
|
||||
provider_service.save_custom_provider_config(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider,
|
||||
config=args['token']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {'result': 'success'}, 201
|
||||
|
||||
|
||||
class ProviderTokenValidateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
if provider == 'openai':
|
||||
args['token'] = {
|
||||
'openai_api_key': args['token']
|
||||
}
|
||||
|
||||
result = True
|
||||
error = None
|
||||
|
||||
try:
|
||||
provider_service.custom_provider_config_validate(
|
||||
provider_name=provider,
|
||||
config=args['token']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
error = str(ex)
|
||||
|
||||
response = {'result': 'success' if result else 'error'}
|
||||
|
||||
if not result:
|
||||
response['error'] = error
|
||||
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
|
||||
endpoint='workspaces_current_providers_token') # PUT for updating provider token
|
||||
api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
|
||||
endpoint='workspaces_current_providers_token_validate') # POST for validating provider token
|
||||
|
||||
api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list
|
|
@ -34,7 +34,6 @@ tenant_fields = {
|
|||
'status': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'role': fields.String,
|
||||
'providers': fields.List(fields.Nested(provider_fields)),
|
||||
'in_trial': fields.Boolean,
|
||||
'trial_end_reason': fields.String,
|
||||
'custom_config': fields.Raw(attribute='custom_config'),
|
||||
|
|
|
@ -9,8 +9,8 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
|
|||
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
|
||||
ProviderNotSupportSpeechToTextError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App, AppModelConfig
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
|
@ -49,8 +49,7 @@ class AudioApi(AppApiResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
|
|
@ -13,9 +13,10 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
|
|||
ConversationCompletedError, CompletionRequestError, ProviderQuotaExceededError, \
|
||||
ProviderModelCurrentlyNotSupportError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
|
||||
|
@ -47,7 +48,7 @@ class CompletionApi(AppApiResource):
|
|||
app_model=app_model,
|
||||
user=end_user,
|
||||
args=args,
|
||||
from_source='api',
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
|
@ -65,8 +66,7 @@ class CompletionApi(AppApiResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
@ -80,7 +80,7 @@ class CompletionStopApi(AppApiResource):
|
|||
if app_model.mode != 'completion':
|
||||
raise AppUnavailableError()
|
||||
|
||||
PubHandler.stop(end_user, task_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
@ -112,7 +112,7 @@ class ChatApi(AppApiResource):
|
|||
app_model=app_model,
|
||||
user=end_user,
|
||||
args=args,
|
||||
from_source='api',
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
|
@ -130,8 +130,7 @@ class ChatApi(AppApiResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
@ -145,7 +144,7 @@ class ChatStopApi(AppApiResource):
|
|||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
PubHandler.stop(end_user, task_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
@ -171,8 +170,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
|||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
|
||||
except ValueError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
|
||||
|
|
|
@ -4,11 +4,11 @@ import services.dataset_service
|
|||
from controllers.service_api import api
|
||||
from controllers.service_api.dataset.error import DatasetNameDuplicateError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from libs.login import current_user
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
|
@ -27,12 +27,20 @@ class DatasetApi(DatasetApiResource):
|
|||
datasets, total = DatasetService.get_datasets(page, limit, provider,
|
||||
tenant_id, current_user)
|
||||
# check embedding setting
|
||||
provider_service = ProviderService()
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
|
||||
ModelType.EMBEDDINGS.value)
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
|
||||
embedding_models = configurations.get_models(
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
only_active=True
|
||||
)
|
||||
|
||||
model_names = []
|
||||
for valid_model in valid_model_list:
|
||||
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
if item['indexing_technique'] == 'high_quality':
|
||||
|
|
|
@ -13,7 +13,7 @@ from controllers.service_api.dataset.error import ArchivedDocumentImmutableError
|
|||
NoFileUploadedError, TooManyFilesError
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
|
||||
from libs.login import current_user
|
||||
from core.model_providers.error import ProviderTokenNotInitError
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from fields.document_fields import document_fields, document_status_fields
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
|
|
@ -4,8 +4,9 @@ from werkzeug.exceptions import NotFound
|
|||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
|
||||
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from fields.segment_fields import segment_fields
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
|
@ -35,10 +36,12 @@ class SegmentApi(DatasetApiResource):
|
|||
# check embedding model setting
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
|
@ -77,10 +80,12 @@ class SegmentApi(DatasetApiResource):
|
|||
# check embedding model setting
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
|
@ -167,10 +172,12 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
if dataset.indexing_technique == 'high_quality':
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
|
|
|
@ -10,8 +10,8 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro
|
|||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
|
@ -51,8 +51,7 @@ class AudioApi(WebApiResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
|
|
@ -13,9 +13,10 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro
|
|||
ProviderNotInitializeError, NotChatAppError, NotCompletionAppError, CompletionRequestError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
|
||||
|
@ -44,7 +45,7 @@ class CompletionApi(WebApiResource):
|
|||
app_model=app_model,
|
||||
user=end_user,
|
||||
args=args,
|
||||
from_source='api',
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
|
@ -62,8 +63,7 @@ class CompletionApi(WebApiResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
@ -77,7 +77,7 @@ class CompletionStopApi(WebApiResource):
|
|||
if app_model.mode != 'completion':
|
||||
raise NotCompletionAppError()
|
||||
|
||||
PubHandler.stop(end_user, task_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
@ -105,7 +105,7 @@ class ChatApi(WebApiResource):
|
|||
app_model=app_model,
|
||||
user=end_user,
|
||||
args=args,
|
||||
from_source='api',
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
|
@ -123,8 +123,7 @@ class ChatApi(WebApiResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
@ -138,7 +137,7 @@ class ChatStopApi(WebApiResource):
|
|||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
PubHandler.stop(end_user, task_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
@ -164,8 +163,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
|||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
|
||||
except ValueError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
|
||||
|
|
|
@ -14,8 +14,9 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi
|
|||
AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from services.completion_service import CompletionService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
|
@ -117,7 +118,14 @@ class MessageMoreLikeThisApi(WebApiResource):
|
|||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
try:
|
||||
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
|
||||
response = CompletionService.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
message_id=message_id,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
return compact_response(response)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
@ -129,8 +137,7 @@ class MessageMoreLikeThisApi(WebApiResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
@ -157,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
|||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
|
||||
except ValueError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
|
||||
|
@ -195,8 +201,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
|
|
101
api/core/agent/agent/agent_llm_callback.py
Normal file
101
api/core/agent/agent/agent_llm_callback.py
Normal file
|
@ -0,0 +1,101 @@
|
|||
import logging
|
||||
from typing import Optional, List
|
||||
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentLLMCallback(Callback):
|
||||
|
||||
def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None:
|
||||
self.agent_callback = agent_callback
|
||||
|
||||
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Before invoke callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
self.agent_callback.on_llm_before_invoke(
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None):
|
||||
"""
|
||||
On new chunk callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param chunk: chunk
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
After invoke callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param result: result
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
self.agent_callback.on_llm_after_invoke(
|
||||
result=result
|
||||
)
|
||||
|
||||
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Invoke error callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param ex: exception
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
self.agent_callback.on_llm_error(
|
||||
error=ex
|
||||
)
|
|
@ -1,28 +1,49 @@
|
|||
from typing import List
|
||||
from typing import List, cast
|
||||
|
||||
from langchain.schema import BaseMessage
|
||||
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
|
||||
class CalcTokenMixin:
|
||||
|
||||
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
return model_instance.get_num_tokens(to_prompt_messages(messages))
|
||||
|
||||
def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
|
||||
"""
|
||||
Got the rest tokens available for the model after excluding messages tokens and completion max tokens
|
||||
|
||||
:param llm:
|
||||
:param model_config:
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
llm_max_tokens = model_instance.model_rules.max_tokens.max
|
||||
completion_max_tokens = model_instance.model_kwargs.max_tokens
|
||||
used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
|
||||
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
max_tokens = (model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)) or 0
|
||||
|
||||
if model_context_tokens is None:
|
||||
return 0
|
||||
|
||||
if max_tokens is None:
|
||||
max_tokens = 0
|
||||
|
||||
prompt_tokens = model_type_instance.get_num_tokens(
|
||||
model_config.model,
|
||||
model_config.credentials,
|
||||
messages
|
||||
)
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
|
||||
|
||||
return rest_tokens
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
|
@ -6,13 +5,14 @@ from langchain.agents.openai_functions_agent.base import _format_intermediate_st
|
|||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.model_manager import ModelInstance
|
||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
|
||||
|
||||
|
@ -20,7 +20,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
|||
"""
|
||||
An Multi Dataset Retrieve Agent driven by Router.
|
||||
"""
|
||||
model_instance: BaseLLM
|
||||
model_config: ModelConfigEntity
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
@ -81,8 +81,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
|||
agent_decision.return_values['output'] = ''
|
||||
return agent_decision
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
raise e
|
||||
|
||||
def real_plan(
|
||||
self,
|
||||
|
@ -106,16 +105,39 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
|||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
prompt_messages = lc_messages_to_prompt_messages(messages)
|
||||
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=self.model_config.provider_model_bundle,
|
||||
model=self.model_config.model,
|
||||
)
|
||||
|
||||
tools = []
|
||||
for function in self.functions:
|
||||
tool = PromptMessageTool(
|
||||
**function
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
model_parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content=result.content,
|
||||
content=result.message.content or "",
|
||||
additional_kwargs={
|
||||
'function_call': result.function_call
|
||||
'function_call': {
|
||||
'id': result.message.tool_calls[0].id,
|
||||
**result.message.tool_calls[0].function.dict()
|
||||
} if result.message.tool_calls else None
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -133,7 +155,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
|||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_instance: BaseLLM,
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
|
@ -147,7 +169,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
|||
system_message=system_message,
|
||||
)
|
||||
return cls(
|
||||
model_instance=model_instance,
|
||||
model_config=model_config,
|
||||
llm=FakeLLM(response=''),
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
|
||||
|
@ -13,18 +13,23 @@ from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage,
|
|||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.agent.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.model_manager import ModelInstance
|
||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
|
||||
|
||||
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_model_instance: BaseLLM = None
|
||||
model_instance: BaseLLM
|
||||
summary_model_config: ModelConfigEntity = None
|
||||
model_config: ModelConfigEntity
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
@ -38,13 +43,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
|||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_instance: BaseLLM,
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
prompt = cls.create_prompt(
|
||||
|
@ -52,11 +58,12 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
|||
system_message=system_message,
|
||||
)
|
||||
return cls(
|
||||
model_instance=model_instance,
|
||||
model_config=model_config,
|
||||
llm=FakeLLM(response=''),
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
agent_llm_callback=agent_llm_callback,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -67,28 +74,49 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
|||
:param query:
|
||||
:return:
|
||||
"""
|
||||
original_max_tokens = self.model_instance.model_kwargs.max_tokens
|
||||
self.model_instance.model_kwargs.max_tokens = 40
|
||||
original_max_tokens = 0
|
||||
for parameter_rule in self.model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
|
||||
or self.model_config.parameters.get(parameter_rule.use_template)) or 0
|
||||
|
||||
self.model_config.parameters['max_tokens'] = 40
|
||||
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
try:
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
callbacks=None
|
||||
prompt_messages = lc_messages_to_prompt_messages(messages)
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=self.model_config.provider_model_bundle,
|
||||
model=self.model_config.model,
|
||||
)
|
||||
|
||||
tools = []
|
||||
for function in self.functions:
|
||||
tool = PromptMessageTool(
|
||||
**function
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
model_parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
raise e
|
||||
|
||||
function_call = result.function_call
|
||||
self.model_config.parameters['max_tokens'] = original_max_tokens
|
||||
|
||||
self.model_instance.model_kwargs.max_tokens = original_max_tokens
|
||||
|
||||
return True if function_call else False
|
||||
return True if result.message.tool_calls else False
|
||||
|
||||
def plan(
|
||||
self,
|
||||
|
@ -113,22 +141,46 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
|||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
|
||||
prompt_messages = lc_messages_to_prompt_messages(messages)
|
||||
|
||||
# summarize messages if rest_tokens < 0
|
||||
try:
|
||||
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
|
||||
prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=self.model_config.provider_model_bundle,
|
||||
model=self.model_config.model,
|
||||
)
|
||||
|
||||
tools = []
|
||||
for function in self.functions:
|
||||
tool = PromptMessageTool(
|
||||
**function
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
|
||||
model_parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content=result.content,
|
||||
content=result.message.content or "",
|
||||
additional_kwargs={
|
||||
'function_call': result.function_call
|
||||
'function_call': {
|
||||
'id': result.message.tool_calls[0].id,
|
||||
**result.message.tool_calls[0].function.dict()
|
||||
} if result.message.tool_calls else None
|
||||
}
|
||||
)
|
||||
agent_decision = _parse_ai_message(ai_message)
|
||||
|
@ -158,9 +210,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
|||
except ValueError:
|
||||
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
|
||||
def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
|
||||
rest_tokens = self.get_message_rest_tokens(
|
||||
self.model_config,
|
||||
messages,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
|
||||
if rest_tokens >= 0:
|
||||
return messages
|
||||
|
@ -210,19 +267,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
|||
ai_prefix="AI",
|
||||
)
|
||||
|
||||
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
|
||||
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
if model_instance.model_provider.provider_name == 'azure_openai':
|
||||
model = model_instance.base_model_name
|
||||
if model_config.provider == 'azure_openai':
|
||||
model = model_config.model
|
||||
model = model.replace("gpt-35", "gpt-3.5")
|
||||
else:
|
||||
model = model_instance.base_model_name
|
||||
model = model_config.credentials.get("base_model_name")
|
||||
|
||||
tiktoken_ = _import_tiktoken()
|
||||
try:
|
||||
|
|
|
@ -1,158 +0,0 @@
|
|||
import json
|
||||
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
"""
|
||||
An Multi Dataset Retrieve Agent driven by Router.
|
||||
"""
|
||||
model_instance: BaseLLM
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
return values
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
return True
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
if len(self.tools) == 0:
|
||||
return AgentFinish(return_values={"output": ''}, log='')
|
||||
elif len(self.tools) == 1:
|
||||
tool = next(iter(self.tools))
|
||||
tool = cast(DatasetRetrieverTool, tool)
|
||||
rst = tool.run(tool_input={'query': kwargs['input']})
|
||||
# output = ''
|
||||
# rst_json = json.loads(rst)
|
||||
# for item in rst_json:
|
||||
# output += f'{item["content"]}\n'
|
||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||
|
||||
if intermediate_steps:
|
||||
_, observation = intermediate_steps[-1]
|
||||
return AgentFinish(return_values={"output": observation}, log=observation)
|
||||
|
||||
try:
|
||||
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
|
||||
if isinstance(agent_decision, AgentAction):
|
||||
tool_inputs = agent_decision.tool_input
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
else:
|
||||
agent_decision.return_values['output'] = ''
|
||||
return agent_decision
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
def real_plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content=result.content,
|
||||
additional_kwargs={
|
||||
'function_call': result.function_call
|
||||
}
|
||||
)
|
||||
|
||||
agent_decision = _parse_ai_message(ai_message)
|
||||
return agent_decision
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
prompt = cls.create_prompt(
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=system_message,
|
||||
)
|
||||
return cls(
|
||||
model_instance=model_instance,
|
||||
llm=FakeLLM(response=''),
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
|
@ -12,9 +12,7 @@ from langchain.tools import BaseTool
|
|||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.entity.model_params import ModelMode
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
||||
|
@ -101,8 +99,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
|||
try:
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
except Exception as e:
|
||||
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
raise e
|
||||
|
||||
try:
|
||||
agent_decision = self.output_parser.parse(full_output)
|
||||
|
@ -119,6 +116,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
|||
except OutputParserException:
|
||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
|
@ -193,7 +191,7 @@ Thought: {agent_scratchpad}
|
|||
raise ValueError("agent_scratchpad should be of type string.")
|
||||
if agent_scratchpad:
|
||||
llm_chain = cast(LLMChain, self.llm_chain)
|
||||
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
|
||||
if llm_chain.model_config.mode == "chat":
|
||||
return (
|
||||
f"This was your previous work "
|
||||
f"(but I haven't seen any of it! I only see what "
|
||||
|
@ -207,7 +205,7 @@ Thought: {agent_scratchpad}
|
|||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_instance: BaseLLM,
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
|
@ -221,7 +219,7 @@ Thought: {agent_scratchpad}
|
|||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
if model_instance.model_mode == ModelMode.CHAT:
|
||||
if model_config.mode == "chat":
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
|
@ -238,10 +236,16 @@ Thought: {agent_scratchpad}
|
|||
format_instructions=format_instructions,
|
||||
input_variables=input_variables
|
||||
)
|
||||
|
||||
llm_chain = LLMChain(
|
||||
model_instance=model_instance,
|
||||
model_config=model_config,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
_output_parser = output_parser
|
||||
|
|
|
@ -13,10 +13,11 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage,
|
|||
from langchain.tools import BaseTool
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.agent.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.entity.model_params import ModelMode
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
||||
|
@ -54,7 +55,7 @@ Action:
|
|||
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_model_instance: BaseLLM = None
|
||||
summary_model_config: ModelConfigEntity = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
|||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
along with observatons
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
|
@ -96,15 +97,16 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
|||
if prompts:
|
||||
messages = prompts[0].to_messages()
|
||||
|
||||
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
|
||||
prompt_messages = lc_messages_to_prompt_messages(messages)
|
||||
|
||||
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
|
||||
if rest_tokens < 0:
|
||||
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
|
||||
|
||||
try:
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
except Exception as e:
|
||||
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
raise e
|
||||
|
||||
try:
|
||||
agent_decision = self.output_parser.parse(full_output)
|
||||
|
@ -119,7 +121,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
|||
"I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
|
||||
if len(intermediate_steps) >= 2 and self.summary_model_instance:
|
||||
if len(intermediate_steps) >= 2 and self.summary_model_config:
|
||||
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
|
||||
should_summary_messages = [AIMessage(content=observation)
|
||||
for _, observation in should_summary_intermediate_steps]
|
||||
|
@ -153,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
|||
ai_prefix="AI",
|
||||
)
|
||||
|
||||
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
|
||||
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
@classmethod
|
||||
|
@ -229,7 +231,7 @@ Thought: {agent_scratchpad}
|
|||
raise ValueError("agent_scratchpad should be of type string.")
|
||||
if agent_scratchpad:
|
||||
llm_chain = cast(LLMChain, self.llm_chain)
|
||||
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
|
||||
if llm_chain.model_config.mode == "chat":
|
||||
return (
|
||||
f"This was your previous work "
|
||||
f"(but I haven't seen any of it! I only see what "
|
||||
|
@ -243,7 +245,7 @@ Thought: {agent_scratchpad}
|
|||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_instance: BaseLLM,
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
|
@ -253,11 +255,12 @@ Thought: {agent_scratchpad}
|
|||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
if model_instance.model_mode == ModelMode.CHAT:
|
||||
if model_config.mode == "chat":
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
|
@ -275,9 +278,15 @@ Thought: {agent_scratchpad}
|
|||
input_variables=input_variables,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
model_instance=model_instance,
|
||||
model_config=model_config,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
agent_llm_callback=agent_llm_callback,
|
||||
parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
_output_parser = output_parser
|
||||
|
|
|
@ -4,10 +4,10 @@ from typing import Union, Optional
|
|||
|
||||
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from core.agent.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
|
||||
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
|
||||
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
|
||||
|
@ -15,9 +15,11 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
|
|||
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
|
||||
from langchain.agents import AgentExecutor as LCAgentExecutor
|
||||
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.entities.message_entities import prompt_messages_to_lc_messages
|
||||
from core.helper import moderation
|
||||
from core.model_providers.error import LLMError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
@ -31,14 +33,15 @@ class PlanningStrategy(str, enum.Enum):
|
|||
|
||||
class AgentConfiguration(BaseModel):
|
||||
strategy: PlanningStrategy
|
||||
model_instance: BaseLLM
|
||||
model_config: ModelConfigEntity
|
||||
tools: list[BaseTool]
|
||||
summary_model_instance: BaseLLM = None
|
||||
memory: Optional[BaseChatMemory] = None
|
||||
summary_model_config: Optional[ModelConfigEntity] = None
|
||||
memory: Optional[TokenBufferMemory] = None
|
||||
callbacks: Callbacks = None
|
||||
max_iterations: int = 6
|
||||
max_execution_time: Optional[float] = None
|
||||
early_stopping_method: str = "generate"
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None
|
||||
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
|
||||
|
||||
class Config:
|
||||
|
@ -62,34 +65,42 @@ class AgentExecutor:
|
|||
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
|
||||
if self.configuration.strategy == PlanningStrategy.REACT:
|
||||
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
model_config=self.configuration.model_config,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
summary_model_instance=self.configuration.summary_model_instance
|
||||
if self.configuration.summary_model_instance else None,
|
||||
summary_model_config=self.configuration.summary_model_config
|
||||
if self.configuration.summary_model_config else None,
|
||||
agent_llm_callback=self.configuration.agent_llm_callback,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
|
||||
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
model_config=self.configuration.model_config,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_model_instance=self.configuration.summary_model_instance
|
||||
if self.configuration.summary_model_instance else None,
|
||||
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
|
||||
if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_model_config=self.configuration.summary_model_config
|
||||
if self.configuration.summary_model_config else None,
|
||||
agent_llm_callback=self.configuration.agent_llm_callback,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.ROUTER:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
|
||||
self.configuration.tools = [t for t in self.configuration.tools
|
||||
if isinstance(t, DatasetRetrieverTool)
|
||||
or isinstance(t, DatasetMultiRetrieverTool)]
|
||||
agent = MultiDatasetRouterAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
model_config=self.configuration.model_config,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
|
||||
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
|
||||
if self.configuration.memory else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
|
||||
self.configuration.tools = [t for t in self.configuration.tools
|
||||
if isinstance(t, DatasetRetrieverTool)
|
||||
or isinstance(t, DatasetMultiRetrieverTool)]
|
||||
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
model_config=self.configuration.model_config,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
verbose=True
|
||||
|
@ -104,11 +115,11 @@ class AgentExecutor:
|
|||
|
||||
def run(self, query: str) -> AgentExecuteResult:
|
||||
moderation_result = moderation.check_moderation(
|
||||
self.configuration.model_instance.model_provider,
|
||||
self.configuration.model_config,
|
||||
query
|
||||
)
|
||||
|
||||
if not moderation_result:
|
||||
if moderation_result:
|
||||
return AgentExecuteResult(
|
||||
output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
|
||||
strategy=self.configuration.strategy,
|
||||
|
@ -118,7 +129,6 @@ class AgentExecutor:
|
|||
agent_executor = LCAgentExecutor.from_agent_and_tools(
|
||||
agent=self.agent,
|
||||
tools=self.configuration.tools,
|
||||
memory=self.configuration.memory,
|
||||
max_iterations=self.configuration.max_iterations,
|
||||
max_execution_time=self.configuration.max_execution_time,
|
||||
early_stopping_method=self.configuration.early_stopping_method,
|
||||
|
@ -126,8 +136,8 @@ class AgentExecutor:
|
|||
)
|
||||
|
||||
try:
|
||||
output = agent_executor.run(query)
|
||||
except LLMError as ex:
|
||||
output = agent_executor.run(input=query)
|
||||
except InvokeError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logging.exception("agent_executor run failed")
|
||||
|
|
251
api/core/app_runner/agent_app_runner.py
Normal file
251
api/core/app_runner/agent_app_runner.py
Normal file
|
@ -0,0 +1,251 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from core.agent.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.app_runner.app_runner import AppRunner
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.entities.application_entities import ApplicationGenerateEntity, PromptTemplateEntity, ModelConfigEntity
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.features.agent_runner import AgentRunnerFeature
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message, App, MessageChain, MessageAgentThought
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentApplicationRunner(AppRunner):
|
||||
"""
|
||||
Agent Application Runner
|
||||
"""
|
||||
|
||||
def run(self, application_generate_entity: ApplicationGenerateEntity,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message) -> None:
|
||||
"""
|
||||
Run agent application
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
:return:
|
||||
"""
|
||||
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError(f"App not found")
|
||||
|
||||
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
# Pre-calculate the number of tokens of the prompt messages,
|
||||
# and return the rest number of tokens by model context token size limit and max token size limit.
|
||||
# If the rest number of tokens is not enough, raise exception.
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
# Not Include: memory, external data, dataset context
|
||||
self.get_pre_calculate_rest_tokens(
|
||||
app_record=app_record,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query
|
||||
)
|
||||
|
||||
memory = None
|
||||
if application_generate_entity.conversation_id:
|
||||
# get memory of conversation (read-only)
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
|
||||
model=app_orchestration_config.model_config.model
|
||||
)
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=conversation,
|
||||
model_instance=model_instance
|
||||
)
|
||||
|
||||
# reorganize all inputs and template to prompt messages
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
# memory(optional)
|
||||
prompt_messages, stop = self.originze_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
context=None,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
# Create MessageChain
|
||||
message_chain = self._init_message_chain(
|
||||
message=message,
|
||||
query=query
|
||||
)
|
||||
|
||||
# add agent callback to record agent thoughts
|
||||
agent_callback = AgentLoopGatherCallbackHandler(
|
||||
model_config=app_orchestration_config.model_config,
|
||||
message=message,
|
||||
queue_manager=queue_manager,
|
||||
message_chain=message_chain
|
||||
)
|
||||
|
||||
# init LLM Callback
|
||||
agent_llm_callback = AgentLLMCallback(
|
||||
agent_callback=agent_callback
|
||||
)
|
||||
|
||||
agent_runner = AgentRunnerFeature(
|
||||
tenant_id=application_generate_entity.tenant_id,
|
||||
app_orchestration_config=app_orchestration_config,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
config=app_orchestration_config.agent,
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id=application_generate_entity.user_id,
|
||||
agent_llm_callback=agent_llm_callback,
|
||||
callback=agent_callback,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
# agent run
|
||||
result = agent_runner.run(
|
||||
query=query,
|
||||
invoke_from=application_generate_entity.invoke_from
|
||||
)
|
||||
|
||||
if result:
|
||||
self._save_message_chain(
|
||||
message_chain=message_chain,
|
||||
output_text=result
|
||||
)
|
||||
|
||||
if (result
|
||||
and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE
|
||||
and app_orchestration_config.prompt_template.simple_prompt_template
|
||||
):
|
||||
# Direct output if agent result exists and has pre prompt
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_orchestration_config=app_orchestration_config,
|
||||
prompt_messages=prompt_messages,
|
||||
stream=application_generate_entity.stream,
|
||||
text=result,
|
||||
usage=self._get_usage_of_all_agent_thoughts(
|
||||
model_config=app_orchestration_config.model_config,
|
||||
message=message
|
||||
)
|
||||
)
|
||||
else:
|
||||
# As normal LLM run, agent result as context
|
||||
context = result
|
||||
|
||||
# reorganize all inputs and template to prompt messages
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
# memory(optional), external data, dataset context(optional)
|
||||
prompt_messages, stop = self.originze_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
context=context,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
self.recale_llm_max_tokens(
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
# Invoke model
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
|
||||
model=app_orchestration_config.model_config.model
|
||||
)
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_orchestration_config.model_config.parameters,
|
||||
stop=stop,
|
||||
stream=application_generate_entity.stream,
|
||||
user=application_generate_entity.user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream
|
||||
)
|
||||
|
||||
def _init_message_chain(self, message: Message, query: str) -> MessageChain:
|
||||
"""
|
||||
Init MessageChain
|
||||
:param message: message
|
||||
:param query: query
|
||||
:return:
|
||||
"""
|
||||
message_chain = MessageChain(
|
||||
message_id=message.id,
|
||||
type="AgentExecutor",
|
||||
input=json.dumps({
|
||||
"input": query
|
||||
})
|
||||
)
|
||||
|
||||
db.session.add(message_chain)
|
||||
db.session.commit()
|
||||
|
||||
return message_chain
|
||||
|
||||
def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
|
||||
"""
|
||||
Save MessageChain
|
||||
:param message_chain: message chain
|
||||
:param output_text: output text
|
||||
:return:
|
||||
"""
|
||||
message_chain.output = json.dumps({
|
||||
"output": output_text
|
||||
})
|
||||
db.session.commit()
|
||||
|
||||
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
|
||||
message: Message) -> LLMUsage:
|
||||
"""
|
||||
Get usage of all agent thoughts
|
||||
:param model_config: model config
|
||||
:param message: message
|
||||
:return:
|
||||
"""
|
||||
agent_thoughts = (db.session.query(MessageAgentThought)
|
||||
.filter(MessageAgentThought.message_id == message.id).all())
|
||||
|
||||
all_message_tokens = 0
|
||||
all_answer_tokens = 0
|
||||
for agent_thought in agent_thoughts:
|
||||
all_message_tokens += agent_thought.message_tokens
|
||||
all_answer_tokens += agent_thought.answer_tokens
|
||||
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
return model_type_instance._calc_response_usage(
|
||||
model_config.model,
|
||||
model_config.credentials,
|
||||
all_message_tokens,
|
||||
all_answer_tokens
|
||||
)
|
267
api/core/app_runner/app_runner.py
Normal file
267
api/core/app_runner/app_runner.py
Normal file
|
@ -0,0 +1,267 @@
|
|||
import time
|
||||
from typing import cast, Optional, List, Tuple, Generator, Union
|
||||
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
|
||||
from core.file.file_obj import FileObj
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, AssistantPromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from models.model import App
|
||||
|
||||
|
||||
class AppRunner:
|
||||
def get_pre_calculate_rest_tokens(self, app_record: App,
|
||||
model_config: ModelConfigEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileObj],
|
||||
query: Optional[str] = None) -> int:
|
||||
"""
|
||||
Get pre calculate rest tokens
|
||||
:param app_record: app record
|
||||
:param model_config: model config entity
|
||||
:param prompt_template_entity: prompt template entity
|
||||
:param inputs: inputs
|
||||
:param files: files
|
||||
:param query: query
|
||||
:return:
|
||||
"""
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
max_tokens = (model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)) or 0
|
||||
|
||||
if model_context_tokens is None:
|
||||
return -1
|
||||
|
||||
if max_tokens is None:
|
||||
max_tokens = 0
|
||||
|
||||
# get prompt messages without memory and context
|
||||
prompt_messages, stop = self.originze_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=model_config,
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query
|
||||
)
|
||||
|
||||
prompt_tokens = model_type_instance.get_num_tokens(
|
||||
model_config.model,
|
||||
model_config.credentials,
|
||||
prompt_messages
|
||||
)
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
|
||||
if rest_tokens < 0:
|
||||
raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
|
||||
"or shrink the max token, or switch to a llm with a larger token limit size.")
|
||||
|
||||
return rest_tokens
|
||||
|
||||
def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
|
||||
prompt_messages: List[PromptMessage]):
|
||||
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
max_tokens = (model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)) or 0
|
||||
|
||||
if model_context_tokens is None:
|
||||
return -1
|
||||
|
||||
if max_tokens is None:
|
||||
max_tokens = 0
|
||||
|
||||
prompt_tokens = model_type_instance.get_num_tokens(
|
||||
model_config.model,
|
||||
model_config.credentials,
|
||||
prompt_messages
|
||||
)
|
||||
|
||||
if prompt_tokens + max_tokens > model_context_tokens:
|
||||
max_tokens = max(model_context_tokens - prompt_tokens, 16)
|
||||
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
model_config.parameters[parameter_rule.name] = max_tokens
|
||||
|
||||
def originze_prompt_messages(self, app_record: App,
|
||||
model_config: ModelConfigEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileObj],
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None) \
|
||||
-> Tuple[List[PromptMessage], Optional[List[str]]]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
:param context:
|
||||
:param app_record: app record
|
||||
:param model_config: model config entity
|
||||
:param prompt_template_entity: prompt template entity
|
||||
:param inputs: inputs
|
||||
:param files: files
|
||||
:param query: query
|
||||
:param memory: memory
|
||||
:return:
|
||||
"""
|
||||
prompt_transform = PromptTransform()
|
||||
|
||||
# get prompt without memory and context
|
||||
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
prompt_messages, stop = prompt_transform.get_prompt(
|
||||
app_mode=app_record.mode,
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs=inputs,
|
||||
query=query if query else '',
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
)
|
||||
else:
|
||||
prompt_messages = prompt_transform.get_advanced_prompt(
|
||||
app_mode=app_record.mode,
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
)
|
||||
stop = model_config.stop
|
||||
|
||||
return prompt_messages, stop
|
||||
|
||||
def direct_output(self, queue_manager: ApplicationQueueManager,
|
||||
app_orchestration_config: AppOrchestrationConfigEntity,
|
||||
prompt_messages: list,
|
||||
text: str,
|
||||
stream: bool,
|
||||
usage: Optional[LLMUsage] = None) -> None:
|
||||
"""
|
||||
Direct output
|
||||
:param queue_manager: application queue manager
|
||||
:param app_orchestration_config: app orchestration config
|
||||
:param prompt_messages: prompt messages
|
||||
:param text: text
|
||||
:param stream: stream
|
||||
:param usage: usage
|
||||
:return:
|
||||
"""
|
||||
if stream:
|
||||
index = 0
|
||||
for token in text:
|
||||
queue_manager.publish_chunk_message(LLMResultChunk(
|
||||
model=app_orchestration_config.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=AssistantPromptMessage(content=token)
|
||||
)
|
||||
))
|
||||
index += 1
|
||||
time.sleep(0.01)
|
||||
|
||||
queue_manager.publish_message_end(
|
||||
llm_result=LLMResult(
|
||||
model=app_orchestration_config.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=usage if usage else LLMUsage.empty_usage()
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
|
||||
queue_manager: ApplicationQueueManager,
|
||||
stream: bool) -> None:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
:param queue_manager: application queue manager
|
||||
:param stream: stream
|
||||
:return:
|
||||
"""
|
||||
if not stream:
|
||||
self._handle_invoke_result_direct(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager
|
||||
)
|
||||
else:
|
||||
self._handle_invoke_result_stream(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager
|
||||
)
|
||||
|
||||
def _handle_invoke_result_direct(self, invoke_result: LLMResult,
|
||||
queue_manager: ApplicationQueueManager) -> None:
|
||||
"""
|
||||
Handle invoke result direct
|
||||
:param invoke_result: invoke result
|
||||
:param queue_manager: application queue manager
|
||||
:return:
|
||||
"""
|
||||
queue_manager.publish_message_end(
|
||||
llm_result=invoke_result
|
||||
)
|
||||
|
||||
def _handle_invoke_result_stream(self, invoke_result: Generator,
|
||||
queue_manager: ApplicationQueueManager) -> None:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
:param queue_manager: application queue manager
|
||||
:return:
|
||||
"""
|
||||
model = None
|
||||
prompt_messages = []
|
||||
text = ''
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
queue_manager.publish_chunk_message(result)
|
||||
|
||||
text += result.delta.message.content
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
|
||||
if not prompt_messages:
|
||||
prompt_messages = result.prompt_messages
|
||||
|
||||
if not usage and result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
|
||||
llm_result = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=usage
|
||||
)
|
||||
|
||||
queue_manager.publish_message_end(
|
||||
llm_result=llm_result
|
||||
)
|
363
api/core/app_runner/basic_app_runner.py
Normal file
363
api/core/app_runner/basic_app_runner.py
Normal file
|
@ -0,0 +1,363 @@
|
|||
import logging
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from core.app_runner.app_runner import AppRunner
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
|
||||
AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.features.annotation_reply import AnnotationReplyFeature
|
||||
from core.features.dataset_retrieval import DatasetRetrievalFeature
|
||||
from core.features.external_data_fetch import ExternalDataFetchFeature
|
||||
from core.features.hosting_moderation import HostingModerationFeature
|
||||
from core.features.moderation import ModerationFeature
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.moderation.base import ModerationException
|
||||
from core.prompt.prompt_transform import AppMode
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message, App, MessageAnnotation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BasicApplicationRunner(AppRunner):
|
||||
"""
|
||||
Basic Application Runner
|
||||
"""
|
||||
|
||||
def run(self, application_generate_entity: ApplicationGenerateEntity,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message) -> None:
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
:return:
|
||||
"""
|
||||
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError(f"App not found")
|
||||
|
||||
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
# Pre-calculate the number of tokens of the prompt messages,
|
||||
# and return the rest number of tokens by model context token size limit and max token size limit.
|
||||
# If the rest number of tokens is not enough, raise exception.
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
# Not Include: memory, external data, dataset context
|
||||
self.get_pre_calculate_rest_tokens(
|
||||
app_record=app_record,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query
|
||||
)
|
||||
|
||||
memory = None
|
||||
if application_generate_entity.conversation_id:
|
||||
# get memory of conversation (read-only)
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
|
||||
model=app_orchestration_config.model_config.model
|
||||
)
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=conversation,
|
||||
model_instance=model_instance
|
||||
)
|
||||
|
||||
# organize all inputs and template to prompt messages
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
# memory(optional)
|
||||
prompt_messages, stop = self.originze_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
# moderation
|
||||
try:
|
||||
# process sensitive_word_avoidance
|
||||
_, inputs, query = self.moderation_for_inputs(
|
||||
app_id=app_record.id,
|
||||
tenant_id=application_generate_entity.tenant_id,
|
||||
app_orchestration_config_entity=app_orchestration_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
)
|
||||
except ModerationException as e:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_orchestration_config=app_orchestration_config,
|
||||
prompt_messages=prompt_messages,
|
||||
text=str(e),
|
||||
stream=application_generate_entity.stream
|
||||
)
|
||||
return
|
||||
|
||||
if query:
|
||||
# annotation reply
|
||||
annotation_reply = self.query_app_annotations_to_reply(
|
||||
app_record=app_record,
|
||||
message=message,
|
||||
query=query,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from
|
||||
)
|
||||
|
||||
if annotation_reply:
|
||||
queue_manager.publish_annotation_reply(
|
||||
message_annotation_id=annotation_reply.id
|
||||
)
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_orchestration_config=app_orchestration_config,
|
||||
prompt_messages=prompt_messages,
|
||||
text=annotation_reply.content,
|
||||
stream=application_generate_entity.stream
|
||||
)
|
||||
return
|
||||
|
||||
# fill in variable inputs from external data tools if exists
|
||||
external_data_tools = app_orchestration_config.external_data_variables
|
||||
if external_data_tools:
|
||||
inputs = self.fill_in_inputs_from_external_data_tools(
|
||||
tenant_id=app_record.tenant_id,
|
||||
app_id=app_record.id,
|
||||
external_data_tools=external_data_tools,
|
||||
inputs=inputs,
|
||||
query=query
|
||||
)
|
||||
|
||||
# get context from datasets
|
||||
context = None
|
||||
if app_orchestration_config.dataset:
|
||||
context = self.retrieve_dataset_context(
|
||||
tenant_id=app_record.tenant_id,
|
||||
app_record=app_record,
|
||||
queue_manager=queue_manager,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
show_retrieve_source=app_orchestration_config.show_retrieve_source,
|
||||
dataset_config=app_orchestration_config.dataset,
|
||||
message=message,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
# reorganize all inputs and template to prompt messages
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
# memory(optional), external data, dataset context(optional)
|
||||
prompt_messages, stop = self.originze_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
context=context,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
# check hosting moderation
|
||||
hosting_moderation_result = self.check_hosting_moderation(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
if hosting_moderation_result:
|
||||
return
|
||||
|
||||
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
self.recale_llm_max_tokens(
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
# Invoke model
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
|
||||
model=app_orchestration_config.model_config.model
|
||||
)
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_orchestration_config.model_config.parameters,
|
||||
stop=stop,
|
||||
stream=application_generate_entity.stream,
|
||||
user=application_generate_entity.user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream
|
||||
)
|
||||
|
||||
def moderation_for_inputs(self, app_id: str,
|
||||
tenant_id: str,
|
||||
app_orchestration_config_entity: AppOrchestrationConfigEntity,
|
||||
inputs: dict,
|
||||
query: str) -> Tuple[bool, dict, str]:
|
||||
"""
|
||||
Process sensitive_word_avoidance.
|
||||
:param app_id: app id
|
||||
:param tenant_id: tenant id
|
||||
:param app_orchestration_config_entity: app orchestration config entity
|
||||
:param inputs: inputs
|
||||
:param query: query
|
||||
:return:
|
||||
"""
|
||||
moderation_feature = ModerationFeature()
|
||||
return moderation_feature.check(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
app_orchestration_config_entity=app_orchestration_config_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
)
|
||||
|
||||
def query_app_annotations_to_reply(self, app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
|
||||
"""
|
||||
Query app annotations to reply
|
||||
:param app_record: app record
|
||||
:param message: message
|
||||
:param query: query
|
||||
:param user_id: user id
|
||||
:param invoke_from: invoke from
|
||||
:return:
|
||||
"""
|
||||
annotation_reply_feature = AnnotationReplyFeature()
|
||||
return annotation_reply_feature.query(
|
||||
app_record=app_record,
|
||||
message=message,
|
||||
query=query,
|
||||
user_id=user_id,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
|
||||
def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
|
||||
app_id: str,
|
||||
external_data_tools: list[ExternalDataVariableEntity],
|
||||
inputs: dict,
|
||||
query: str) -> dict:
|
||||
"""
|
||||
Fill in variable inputs from external data tools if exists.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param app_id: app id
|
||||
:param external_data_tools: external data tools configs
|
||||
:param inputs: the inputs
|
||||
:param query: the query
|
||||
:return: the filled inputs
|
||||
"""
|
||||
external_data_fetch_feature = ExternalDataFetchFeature()
|
||||
return external_data_fetch_feature.fetch(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
external_data_tools=external_data_tools,
|
||||
inputs=inputs,
|
||||
query=query
|
||||
)
|
||||
|
||||
def retrieve_dataset_context(self, tenant_id: str,
|
||||
app_record: App,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
model_config: ModelConfigEntity,
|
||||
dataset_config: DatasetEntity,
|
||||
show_retrieve_source: bool,
|
||||
message: Message,
|
||||
inputs: dict,
|
||||
query: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
|
||||
"""
|
||||
Retrieve dataset context
|
||||
:param tenant_id: tenant id
|
||||
:param app_record: app record
|
||||
:param queue_manager: queue manager
|
||||
:param model_config: model config
|
||||
:param dataset_config: dataset config
|
||||
:param show_retrieve_source: show retrieve source
|
||||
:param message: message
|
||||
:param inputs: inputs
|
||||
:param query: query
|
||||
:param user_id: user id
|
||||
:param invoke_from: invoke from
|
||||
:param memory: memory
|
||||
:return:
|
||||
"""
|
||||
hit_callback = DatasetIndexToolCallbackHandler(
|
||||
queue_manager,
|
||||
app_record.id,
|
||||
message.id,
|
||||
user_id,
|
||||
invoke_from
|
||||
)
|
||||
|
||||
if (app_record.mode == AppMode.COMPLETION.value and dataset_config
|
||||
and dataset_config.retrieve_config.query_variable):
|
||||
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
|
||||
|
||||
dataset_retrieval = DatasetRetrievalFeature()
|
||||
return dataset_retrieval.retrieve(
|
||||
tenant_id=tenant_id,
|
||||
model_config=model_config,
|
||||
config=dataset_config,
|
||||
query=query,
|
||||
invoke_from=invoke_from,
|
||||
show_retrieve_source=show_retrieve_source,
|
||||
hit_callback=hit_callback,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
prompt_messages: list[PromptMessage]) -> bool:
|
||||
"""
|
||||
Check hosting moderation
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param prompt_messages: prompt messages
|
||||
:return:
|
||||
"""
|
||||
hosting_moderation_feature = HostingModerationFeature()
|
||||
moderation_result = hosting_moderation_feature.check(
|
||||
application_generate_entity=application_generate_entity,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
if moderation_result:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_orchestration_config=application_generate_entity.app_orchestration_config_entity,
|
||||
prompt_messages=prompt_messages,
|
||||
text="I apologize for any confusion, " \
|
||||
"but I'm an AI assistant to be helpful, harmless, and honest.",
|
||||
stream=application_generate_entity.stream
|
||||
)
|
||||
|
||||
return moderation_result
|
483
api/core/app_runner/generate_task_pipeline.py
Normal file
483
api/core/app_runner/generate_task_pipeline.py
Normal file
|
@ -0,0 +1,483 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Union, Generator, cast, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
|
||||
from core.entities.application_entities import ApplicationGenerateEntity
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
|
||||
QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
|
||||
AnnotationReplyEvent
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, \
|
||||
TextPromptMessageContent, PromptMessageContentType, ImagePromptMessageContent, PromptMessage
|
||||
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message, Conversation, MessageAgentThought
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskState(BaseModel):
|
||||
"""
|
||||
TaskState entity
|
||||
"""
|
||||
llm_result: LLMResult
|
||||
metadata: dict = {}
|
||||
|
||||
|
||||
class GenerateTaskPipeline:
|
||||
"""
|
||||
GenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
def __init__(self, application_generate_entity: ApplicationGenerateEntity,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message) -> None:
|
||||
"""
|
||||
Initialize GenerateTaskPipeline.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
"""
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._queue_manager = queue_manager
|
||||
self._conversation = conversation
|
||||
self._message = message
|
||||
self._task_state = TaskState(
|
||||
llm_result=LLMResult(
|
||||
model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
|
||||
prompt_messages=[],
|
||||
message=AssistantPromptMessage(content=""),
|
||||
usage=LLMUsage.empty_usage()
|
||||
)
|
||||
)
|
||||
self._start_at = time.perf_counter()
|
||||
self._output_moderation_handler = self._init_output_moderation()
|
||||
|
||||
def process(self, stream: bool) -> Union[dict, Generator]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
:return:
|
||||
"""
|
||||
if stream:
|
||||
return self._process_stream_response()
|
||||
else:
|
||||
return self._process_blocking_response()
|
||||
|
||||
def _process_blocking_response(self) -> dict:
|
||||
"""
|
||||
Process blocking response.
|
||||
:return:
|
||||
"""
|
||||
for queue_message in self._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
raise self._handle_error(event)
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
||||
elif isinstance(event, AnnotationReplyEvent):
|
||||
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
||||
if annotation:
|
||||
account = annotation.account
|
||||
self._task_state.metadata['annotation_reply'] = {
|
||||
'id': annotation.id,
|
||||
'account': {
|
||||
'id': annotation.account_id,
|
||||
'name': account.name if account else 'Dify user'
|
||||
}
|
||||
}
|
||||
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
self._task_state.llm_result = event.llm_result
|
||||
else:
|
||||
model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
|
||||
model = model_config.model
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = 0
|
||||
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
|
||||
prompt_tokens = model_type_instance.get_num_tokens(
|
||||
model,
|
||||
model_config.credentials,
|
||||
self._task_state.llm_result.prompt_messages
|
||||
)
|
||||
|
||||
completion_tokens = 0
|
||||
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
|
||||
completion_tokens = model_type_instance.get_num_tokens(
|
||||
model,
|
||||
model_config.credentials,
|
||||
[self._task_state.llm_result.message]
|
||||
)
|
||||
|
||||
credentials = model_config.credentials
|
||||
|
||||
# transform usage
|
||||
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
|
||||
model,
|
||||
credentials,
|
||||
prompt_tokens,
|
||||
completion_tokens
|
||||
)
|
||||
|
||||
# response moderation
|
||||
if self._output_moderation_handler:
|
||||
self._output_moderation_handler.stop_thread()
|
||||
|
||||
self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
|
||||
completion=self._task_state.llm_result.message.content,
|
||||
public_event=False
|
||||
)
|
||||
|
||||
# Save message
|
||||
self._save_message(event.llm_result)
|
||||
|
||||
response = {
|
||||
'event': 'message',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'id': self._message.id,
|
||||
'mode': self._conversation.mode,
|
||||
'answer': event.llm_result.message.content,
|
||||
'metadata': {},
|
||||
'created_at': int(self._message.created_at.timestamp())
|
||||
}
|
||||
|
||||
if self._conversation.mode == 'chat':
|
||||
response['conversation_id'] = self._conversation.id
|
||||
|
||||
if self._task_state.metadata:
|
||||
response['metadata'] = self._task_state.metadata
|
||||
|
||||
return response
|
||||
else:
|
||||
continue
|
||||
|
||||
def _process_stream_response(self) -> Generator:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
raise self._handle_error(event)
|
||||
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
self._task_state.llm_result = event.llm_result
|
||||
else:
|
||||
model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
|
||||
model = model_config.model
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = 0
|
||||
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
|
||||
prompt_tokens = model_type_instance.get_num_tokens(
|
||||
model,
|
||||
model_config.credentials,
|
||||
self._task_state.llm_result.prompt_messages
|
||||
)
|
||||
|
||||
completion_tokens = 0
|
||||
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
|
||||
completion_tokens = model_type_instance.get_num_tokens(
|
||||
model,
|
||||
model_config.credentials,
|
||||
[self._task_state.llm_result.message]
|
||||
)
|
||||
|
||||
credentials = model_config.credentials
|
||||
|
||||
# transform usage
|
||||
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
|
||||
model,
|
||||
credentials,
|
||||
prompt_tokens,
|
||||
completion_tokens
|
||||
)
|
||||
|
||||
# response moderation
|
||||
if self._output_moderation_handler:
|
||||
self._output_moderation_handler.stop_thread()
|
||||
|
||||
self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
|
||||
completion=self._task_state.llm_result.message.content,
|
||||
public_event=False
|
||||
)
|
||||
|
||||
self._output_moderation_handler = None
|
||||
|
||||
replace_response = {
|
||||
'event': 'message_replace',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'message_id': self._message.id,
|
||||
'answer': self._task_state.llm_result.message.content,
|
||||
'created_at': int(self._message.created_at.timestamp())
|
||||
}
|
||||
|
||||
if self._conversation.mode == 'chat':
|
||||
replace_response['conversation_id'] = self._conversation.id
|
||||
|
||||
yield self._yield_response(replace_response)
|
||||
|
||||
# Save message
|
||||
self._save_message(self._task_state.llm_result)
|
||||
|
||||
response = {
|
||||
'event': 'message_end',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'id': self._message.id,
|
||||
}
|
||||
|
||||
if self._conversation.mode == 'chat':
|
||||
response['conversation_id'] = self._conversation.id
|
||||
|
||||
if self._task_state.metadata:
|
||||
response['metadata'] = self._task_state.metadata
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
||||
elif isinstance(event, AnnotationReplyEvent):
|
||||
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
||||
if annotation:
|
||||
account = annotation.account
|
||||
self._task_state.metadata['annotation_reply'] = {
|
||||
'id': annotation.id,
|
||||
'account': {
|
||||
'id': annotation.account_id,
|
||||
'name': account.name if account else 'Dify user'
|
||||
}
|
||||
}
|
||||
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
elif isinstance(event, QueueAgentThoughtEvent):
|
||||
agent_thought = (
|
||||
db.session.query(MessageAgentThought)
|
||||
.filter(MessageAgentThought.id == event.agent_thought_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if agent_thought:
|
||||
response = {
|
||||
'event': 'agent_thought',
|
||||
'id': agent_thought.id,
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'message_id': self._message.id,
|
||||
'position': agent_thought.position,
|
||||
'thought': agent_thought.thought,
|
||||
'tool': agent_thought.tool,
|
||||
'tool_input': agent_thought.tool_input,
|
||||
'created_at': int(self._message.created_at.timestamp())
|
||||
}
|
||||
|
||||
if self._conversation.mode == 'chat':
|
||||
response['conversation_id'] = self._conversation.id
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueMessageEvent):
|
||||
chunk = event.chunk
|
||||
delta_text = chunk.delta.message.content
|
||||
if delta_text is None:
|
||||
continue
|
||||
|
||||
if not self._task_state.llm_result.prompt_messages:
|
||||
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
|
||||
|
||||
if self._output_moderation_handler:
|
||||
if self._output_moderation_handler.should_direct_output():
|
||||
# stop subscribe new token when output moderation should direct output
|
||||
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
|
||||
self._queue_manager.publish_chunk_message(LLMResultChunk(
|
||||
model=self._task_state.llm_result.model,
|
||||
prompt_messages=self._task_state.llm_result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
|
||||
)
|
||||
))
|
||||
self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION))
|
||||
continue
|
||||
else:
|
||||
self._output_moderation_handler.append_new_token(delta_text)
|
||||
|
||||
self._task_state.llm_result.message.content += delta_text
|
||||
response = self._handle_chunk(delta_text)
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
response = {
|
||||
'event': 'message_replace',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'message_id': self._message.id,
|
||||
'answer': event.text,
|
||||
'created_at': int(self._message.created_at.timestamp())
|
||||
}
|
||||
|
||||
if self._conversation.mode == 'chat':
|
||||
response['conversation_id'] = self._conversation.id
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
yield "event: ping\n\n"
|
||||
else:
|
||||
continue
|
||||
|
||||
def _save_message(self, llm_result: LLMResult) -> None:
|
||||
"""
|
||||
Save message.
|
||||
:param llm_result: llm result
|
||||
:return:
|
||||
"""
|
||||
usage = llm_result.usage
|
||||
|
||||
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||
|
||||
self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
|
||||
self._message.message_tokens = usage.prompt_tokens
|
||||
self._message.message_unit_price = usage.prompt_unit_price
|
||||
self._message.message_price_unit = usage.prompt_price_unit
|
||||
self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
|
||||
if llm_result.message.content else ''
|
||||
self._message.answer_tokens = usage.completion_tokens
|
||||
self._message.answer_unit_price = usage.completion_unit_price
|
||||
self._message.answer_price_unit = usage.completion_price_unit
|
||||
self._message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
self._message.total_price = usage.total_price
|
||||
|
||||
db.session.commit()
|
||||
|
||||
message_was_created.send(
|
||||
self._message,
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
conversation=self._conversation,
|
||||
is_first_message=self._application_generate_entity.conversation_id is None,
|
||||
extras=self._application_generate_entity.extras
|
||||
)
|
||||
|
||||
def _handle_chunk(self, text: str) -> dict:
|
||||
"""
|
||||
Handle completed event.
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
response = {
|
||||
'event': 'message',
|
||||
'id': self._message.id,
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'message_id': self._message.id,
|
||||
'answer': text,
|
||||
'created_at': int(self._message.created_at.timestamp())
|
||||
}
|
||||
|
||||
if self._conversation.mode == 'chat':
|
||||
response['conversation_id'] = self._conversation.id
|
||||
|
||||
return response
|
||||
|
||||
def _handle_error(self, event: QueueErrorEvent) -> Exception:
|
||||
"""
|
||||
Handle error event.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
logger.debug("error: %s", event.error)
|
||||
e = event.error
|
||||
|
||||
if isinstance(e, InvokeAuthorizationError):
|
||||
return InvokeAuthorizationError('Incorrect API key provided')
|
||||
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
|
||||
return e
|
||||
else:
|
||||
return Exception(e.description if getattr(e, 'description', None) is not None else str(e))
|
||||
|
||||
def _yield_response(self, response: dict) -> str:
|
||||
"""
|
||||
Yield response.
|
||||
:param response: response
|
||||
:return:
|
||||
"""
|
||||
return "data: " + json.dumps(response) + "\n\n"
|
||||
|
||||
def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
|
||||
"""
|
||||
Prompt messages to prompt for saving.
|
||||
:param prompt_messages: prompt messages
|
||||
:return:
|
||||
"""
|
||||
prompts = []
|
||||
if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
|
||||
for prompt_message in prompt_messages:
|
||||
if prompt_message.role == PromptMessageRole.USER:
|
||||
role = 'user'
|
||||
elif prompt_message.role == PromptMessageRole.ASSISTANT:
|
||||
role = 'assistant'
|
||||
elif prompt_message.role == PromptMessageRole.SYSTEM:
|
||||
role = 'system'
|
||||
else:
|
||||
continue
|
||||
|
||||
text = ''
|
||||
files = []
|
||||
if isinstance(prompt_message.content, list):
|
||||
for content in prompt_message.content:
|
||||
if content.type == PromptMessageContentType.TEXT:
|
||||
content = cast(TextPromptMessageContent, content)
|
||||
text += content.data
|
||||
else:
|
||||
content = cast(ImagePromptMessageContent, content)
|
||||
files.append({
|
||||
"type": 'image',
|
||||
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
|
||||
"detail": content.detail.value
|
||||
})
|
||||
else:
|
||||
text = prompt_message.content
|
||||
|
||||
prompts.append({
|
||||
"role": role,
|
||||
"text": text,
|
||||
"files": files
|
||||
})
|
||||
else:
|
||||
prompts.append({
|
||||
"role": 'user',
|
||||
"text": prompt_messages[0].content
|
||||
})
|
||||
|
||||
return prompts
|
||||
|
||||
def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
|
||||
"""
|
||||
Init output moderation.
|
||||
:return:
|
||||
"""
|
||||
app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
|
||||
sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance
|
||||
|
||||
if sensitive_word_avoidance:
|
||||
return OutputModerationHandler(
|
||||
tenant_id=self._application_generate_entity.tenant_id,
|
||||
app_id=self._application_generate_entity.app_id,
|
||||
rule=ModerationRule(
|
||||
type=sensitive_word_avoidance.type,
|
||||
config=sensitive_word_avoidance.config
|
||||
),
|
||||
on_message_replace_func=self._queue_manager.publish_message_replace
|
||||
)
|
138
api/core/app_runner/moderation_handler.py
Normal file
138
api/core/app_runner/moderation_handler.py
Normal file
|
@ -0,0 +1,138 @@
|
|||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Optional, Dict
|
||||
|
||||
from flask import current_app, Flask
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.moderation.base import ModerationAction, ModerationOutputsResult
|
||||
from core.moderation.factory import ModerationFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModerationRule(BaseModel):
|
||||
type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class OutputModerationHandler(BaseModel):
|
||||
DEFAULT_BUFFER_SIZE: int = 300
|
||||
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
|
||||
rule: ModerationRule
|
||||
on_message_replace_func: Any
|
||||
|
||||
thread: Optional[threading.Thread] = None
|
||||
thread_running: bool = True
|
||||
buffer: str = ''
|
||||
is_final_chunk: bool = False
|
||||
final_output: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def should_direct_output(self):
|
||||
return self.final_output is not None
|
||||
|
||||
def get_final_output(self):
|
||||
return self.final_output
|
||||
|
||||
def append_new_token(self, token: str):
|
||||
self.buffer += token
|
||||
|
||||
if not self.thread:
|
||||
self.thread = self.start_thread()
|
||||
|
||||
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
|
||||
self.buffer = completion
|
||||
self.is_final_chunk = True
|
||||
|
||||
result = self.moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
moderation_buffer=completion
|
||||
)
|
||||
|
||||
if not result or not result.flagged:
|
||||
return completion
|
||||
|
||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
final_output = result.preset_response
|
||||
else:
|
||||
final_output = result.text
|
||||
|
||||
if public_event:
|
||||
self.on_message_replace_func(final_output)
|
||||
|
||||
return final_output
|
||||
|
||||
def start_thread(self) -> threading.Thread:
|
||||
buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
|
||||
thread = threading.Thread(target=self.worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
|
||||
})
|
||||
|
||||
thread.start()
|
||||
|
||||
return thread
|
||||
|
||||
def stop_thread(self):
|
||||
if self.thread and self.thread.is_alive():
|
||||
self.thread_running = False
|
||||
|
||||
def worker(self, flask_app: Flask, buffer_size: int):
|
||||
with flask_app.app_context():
|
||||
current_length = 0
|
||||
while self.thread_running:
|
||||
moderation_buffer = self.buffer
|
||||
buffer_length = len(moderation_buffer)
|
||||
if not self.is_final_chunk:
|
||||
chunk_length = buffer_length - current_length
|
||||
if 0 <= chunk_length < buffer_size:
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
current_length = buffer_length
|
||||
|
||||
result = self.moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
moderation_buffer=moderation_buffer
|
||||
)
|
||||
|
||||
if not result or not result.flagged:
|
||||
continue
|
||||
|
||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
final_output = result.preset_response
|
||||
self.final_output = final_output
|
||||
else:
|
||||
final_output = result.text + self.buffer[len(moderation_buffer):]
|
||||
|
||||
# trigger replace event
|
||||
if self.thread_running:
|
||||
self.on_message_replace_func(final_output)
|
||||
|
||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
break
|
||||
|
||||
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
|
||||
try:
|
||||
moderation_factory = ModerationFactory(
|
||||
name=self.rule.type,
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
config=self.rule.config
|
||||
)
|
||||
|
||||
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error("Moderation Output error: %s", e)
|
||||
|
||||
return None
|
655
api/core/application_manager.py
Normal file
655
api/core/application_manager.py
Normal file
|
@ -0,0 +1,655 @@
|
|||
import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from typing import cast, Optional, Any, Union, Generator, Tuple
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app_runner.agent_app_runner import AgentApplicationRunner
|
||||
from core.app_runner.basic_app_runner import BasicApplicationRunner
|
||||
from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
|
||||
from core.entities.application_entities import ApplicationGenerateEntity, AppOrchestrationConfigEntity, \
|
||||
ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \
|
||||
AdvancedCompletionPromptTemplateEntity, ExternalDataVariableEntity, DatasetEntity, DatasetRetrieveConfigEntity, \
|
||||
AgentEntity, AgentToolEntity, FileUploadEntity, SensitiveWordAvoidanceEntity, InvokeFrom
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.file.file_obj import FileObj
|
||||
from core.errors.error import QuotaExceededError, ProviderTokenNotInitError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser, Conversation, Message, MessageFile, App
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApplicationManager:
|
||||
"""
|
||||
This class is responsible for managing application
|
||||
"""
|
||||
|
||||
def generate(self, tenant_id: str,
|
||||
app_id: str,
|
||||
app_model_config_id: str,
|
||||
app_model_config_dict: dict,
|
||||
app_model_config_override: bool,
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
inputs: dict[str, str],
|
||||
query: Optional[str] = None,
|
||||
files: Optional[list[FileObj]] = None,
|
||||
conversation: Optional[Conversation] = None,
|
||||
stream: bool = False,
|
||||
extras: Optional[dict[str, Any]] = None) \
|
||||
-> Union[dict, Generator]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param tenant_id: workspace ID
|
||||
:param app_id: app ID
|
||||
:param app_model_config_id: app model config id
|
||||
:param app_model_config_dict: app model config dict
|
||||
:param app_model_config_override: app model config override
|
||||
:param user: account or end user
|
||||
:param invoke_from: invoke from source
|
||||
:param inputs: inputs
|
||||
:param query: query
|
||||
:param files: file obj list
|
||||
:param conversation: conversation
|
||||
:param stream: is stream
|
||||
:param extras: extras
|
||||
"""
|
||||
# init task id
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = ApplicationGenerateEntity(
|
||||
task_id=task_id,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
app_model_config_id=app_model_config_id,
|
||||
app_model_config_dict=app_model_config_dict,
|
||||
app_orchestration_config_entity=self._convert_from_app_model_config_dict(
|
||||
tenant_id=tenant_id,
|
||||
app_model_config_dict=app_model_config_dict
|
||||
),
|
||||
app_model_config_override=app_model_config_override,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=conversation.inputs if conversation else inputs,
|
||||
query=query.replace('\x00', '') if query else None,
|
||||
files=files if files else [],
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras
|
||||
)
|
||||
|
||||
# init generate records
|
||||
(
|
||||
conversation,
|
||||
message
|
||||
) = self._init_generate_records(application_generate_entity)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = ApplicationQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'conversation_id': conversation.id,
|
||||
'message_id': message.id,
|
||||
})
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
# return response or stream generator
|
||||
return self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: ApplicationGenerateEntity,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param conversation_id: conversation ID
|
||||
:param message_id: message ID
|
||||
:return:
|
||||
"""
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
|
||||
if application_generate_entity.app_orchestration_config_entity.agent:
|
||||
# agent app
|
||||
runner = AgentApplicationRunner()
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
else:
|
||||
# basic app
|
||||
runner = BasicApplicationRunner()
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided'))
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e)
|
||||
except (ValueError, InvokeError) as e:
|
||||
queue_manager.publish_error(e)
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when generating")
|
||||
queue_manager.publish_error(e)
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
||||
def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
stream: bool = False) -> Union[dict, Generator]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
:param stream: is stream
|
||||
:return:
|
||||
"""
|
||||
# init generate task pipeline
|
||||
generate_task_pipeline = GenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
|
||||
try:
|
||||
return generate_task_pipeline.process(stream=stream)
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise ConversationTaskStoppedException()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
||||
def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
|
||||
-> AppOrchestrationConfigEntity:
|
||||
"""
|
||||
Convert app model config dict to entity.
|
||||
:param tenant_id: tenant ID
|
||||
:param app_model_config_dict: app model config dict
|
||||
:raises ProviderTokenNotInitError: provider token not init error
|
||||
:return: app orchestration config entity
|
||||
"""
|
||||
properties = {}
|
||||
|
||||
copy_app_model_config_dict = app_model_config_dict.copy()
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=tenant_id,
|
||||
provider=copy_app_model_config_dict['model']['provider'],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_name = copy_app_model_config_dict['model']['name']
|
||||
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
# check model credentials
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model=copy_app_model_config_dict['model']['name']
|
||||
)
|
||||
|
||||
if model_credentials is None:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=copy_app_model_config_dict['model']['name'],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
model_name = copy_app_model_config_dict['model']['name']
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||
|
||||
# model config
|
||||
completion_params = copy_app_model_config_dict['model'].get('completion_params')
|
||||
stop = []
|
||||
if 'stop' in completion_params:
|
||||
stop = completion_params['stop']
|
||||
del completion_params['stop']
|
||||
|
||||
# get model mode
|
||||
model_mode = copy_app_model_config_dict['model'].get('mode')
|
||||
if not model_mode:
|
||||
mode_enum = model_type_instance.get_model_mode(
|
||||
model=copy_app_model_config_dict['model']['name'],
|
||||
credentials=model_credentials
|
||||
)
|
||||
|
||||
model_mode = mode_enum.value
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
copy_app_model_config_dict['model']['name'],
|
||||
model_credentials
|
||||
)
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
properties['model_config'] = ModelConfigEntity(
|
||||
provider=copy_app_model_config_dict['model']['provider'],
|
||||
model=copy_app_model_config_dict['model']['name'],
|
||||
model_schema=model_schema,
|
||||
mode=model_mode,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=model_credentials,
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
# prompt template
|
||||
prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type'])
|
||||
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "")
|
||||
properties['prompt_template'] = PromptTemplateEntity(
|
||||
prompt_type=prompt_type,
|
||||
simple_prompt_template=simple_prompt_template
|
||||
)
|
||||
else:
|
||||
advanced_chat_prompt_template = None
|
||||
chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
|
||||
if chat_prompt_config:
|
||||
chat_prompt_messages = []
|
||||
for message in chat_prompt_config.get("prompt", []):
|
||||
chat_prompt_messages.append({
|
||||
"text": message["text"],
|
||||
"role": PromptMessageRole.value_of(message["role"])
|
||||
})
|
||||
|
||||
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
|
||||
messages=chat_prompt_messages
|
||||
)
|
||||
|
||||
advanced_completion_prompt_template = None
|
||||
completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
|
||||
if completion_prompt_config:
|
||||
completion_prompt_template_params = {
|
||||
'prompt': completion_prompt_config['prompt']['text'],
|
||||
}
|
||||
|
||||
if 'conversation_histories_role' in completion_prompt_config:
|
||||
completion_prompt_template_params['role_prefix'] = {
|
||||
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
|
||||
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
|
||||
}
|
||||
|
||||
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
|
||||
**completion_prompt_template_params
|
||||
)
|
||||
|
||||
properties['prompt_template'] = PromptTemplateEntity(
|
||||
prompt_type=prompt_type,
|
||||
advanced_chat_prompt_template=advanced_chat_prompt_template,
|
||||
advanced_completion_prompt_template=advanced_completion_prompt_template
|
||||
)
|
||||
|
||||
# external data variables
|
||||
properties['external_data_variables'] = []
|
||||
external_data_tools = copy_app_model_config_dict.get('external_data_tools', [])
|
||||
for external_data_tool in external_data_tools:
|
||||
if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
|
||||
continue
|
||||
|
||||
properties['external_data_variables'].append(
|
||||
ExternalDataVariableEntity(
|
||||
variable=external_data_tool['variable'],
|
||||
type=external_data_tool['type'],
|
||||
config=external_data_tool['config']
|
||||
)
|
||||
)
|
||||
|
||||
# show retrieve source
|
||||
show_retrieve_source = False
|
||||
retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource')
|
||||
if retriever_resource_dict:
|
||||
if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
|
||||
show_retrieve_source = True
|
||||
|
||||
properties['show_retrieve_source'] = show_retrieve_source
|
||||
|
||||
if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
|
||||
and 'enabled' in copy_app_model_config_dict['agent_mode'] and copy_app_model_config_dict['agent_mode'][
|
||||
'enabled']:
|
||||
agent_dict = copy_app_model_config_dict.get('agent_mode')
|
||||
if agent_dict['strategy'] in ['router', 'react_router']:
|
||||
dataset_ids = []
|
||||
for tool in agent_dict.get('tools', []):
|
||||
key = list(tool.keys())[0]
|
||||
|
||||
if key != 'dataset':
|
||||
continue
|
||||
|
||||
tool_item = tool[key]
|
||||
|
||||
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||
continue
|
||||
|
||||
dataset_id = tool_item['id']
|
||||
dataset_ids.append(dataset_id)
|
||||
|
||||
dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
|
||||
query_variable = copy_app_model_config_dict.get('dataset_query_variable')
|
||||
if dataset_configs['retrieval_model'] == 'single':
|
||||
properties['dataset'] = DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable=query_variable,
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs['retrieval_model']
|
||||
),
|
||||
single_strategy=agent_dict['strategy']
|
||||
)
|
||||
)
|
||||
else:
|
||||
properties['dataset'] = DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable=query_variable,
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs['retrieval_model']
|
||||
),
|
||||
top_k=dataset_configs.get('top_k'),
|
||||
score_threshold=dataset_configs.get('score_threshold'),
|
||||
reranking_model=dataset_configs.get('reranking_model')
|
||||
)
|
||||
)
|
||||
else:
|
||||
if agent_dict['strategy'] == 'react':
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
else:
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
|
||||
agent_tools = []
|
||||
for tool in agent_dict.get('tools', []):
|
||||
key = list(tool.keys())[0]
|
||||
tool_item = tool[key]
|
||||
|
||||
agent_tool_properties = {
|
||||
"tool_id": key
|
||||
}
|
||||
|
||||
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||
continue
|
||||
|
||||
agent_tool_properties["config"] = tool_item
|
||||
agent_tools.append(AgentToolEntity(**agent_tool_properties))
|
||||
|
||||
properties['agent'] = AgentEntity(
|
||||
provider=properties['model_config'].provider,
|
||||
model=properties['model_config'].model,
|
||||
strategy=strategy,
|
||||
tools=agent_tools
|
||||
)
|
||||
|
||||
# file upload
|
||||
file_upload_dict = copy_app_model_config_dict.get('file_upload')
|
||||
if file_upload_dict:
|
||||
if 'image' in file_upload_dict and file_upload_dict['image']:
|
||||
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
|
||||
properties['file_upload'] = FileUploadEntity(
|
||||
image_config={
|
||||
'number_limits': file_upload_dict['image']['number_limits'],
|
||||
'detail': file_upload_dict['image']['detail'],
|
||||
'transfer_methods': file_upload_dict['image']['transfer_methods']
|
||||
}
|
||||
)
|
||||
|
||||
# opening statement
|
||||
properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')
|
||||
|
||||
# suggested questions after answer
|
||||
suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer')
|
||||
if suggested_questions_after_answer_dict:
|
||||
if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
|
||||
properties['suggested_questions_after_answer'] = True
|
||||
|
||||
# more like this
|
||||
more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
|
||||
if more_like_this_dict:
|
||||
if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
|
||||
properties['more_like_this'] = copy_app_model_config_dict.get('opening_statement')
|
||||
|
||||
# speech to text
|
||||
speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
|
||||
if speech_to_text_dict:
|
||||
if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
|
||||
properties['speech_to_text'] = True
|
||||
|
||||
# sensitive word avoidance
|
||||
sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
|
||||
if sensitive_word_avoidance_dict:
|
||||
if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
|
||||
properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
|
||||
type=sensitive_word_avoidance_dict.get('type'),
|
||||
config=sensitive_word_avoidance_dict.get('config'),
|
||||
)
|
||||
|
||||
return AppOrchestrationConfigEntity(**properties)
|
||||
|
||||
def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
|
||||
-> Tuple[Conversation, Message]:
|
||||
"""
|
||||
Initialize generate records
|
||||
:param application_generate_entity: application generate entity
|
||||
:return:
|
||||
"""
|
||||
app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
|
||||
|
||||
model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model=app_orchestration_config_entity.model_config.model,
|
||||
credentials=app_orchestration_config_entity.model_config.credentials
|
||||
)
|
||||
|
||||
app_record = (db.session.query(App)
|
||||
.filter(App.id == application_generate_entity.app_id).first())
|
||||
|
||||
app_mode = app_record.mode
|
||||
|
||||
# get from source
|
||||
end_user_id = None
|
||||
account_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
from_source = 'api'
|
||||
end_user_id = application_generate_entity.user_id
|
||||
else:
|
||||
from_source = 'console'
|
||||
account_id = application_generate_entity.user_id
|
||||
|
||||
override_model_configs = None
|
||||
if application_generate_entity.app_model_config_override:
|
||||
override_model_configs = application_generate_entity.app_model_config_dict
|
||||
|
||||
introduction = ''
|
||||
if app_mode == 'chat':
|
||||
# get conversation introduction
|
||||
introduction = self._get_conversation_introduction(application_generate_entity)
|
||||
|
||||
if not application_generate_entity.conversation_id:
|
||||
conversation = Conversation(
|
||||
app_id=app_record.id,
|
||||
app_model_config_id=application_generate_entity.app_model_config_id,
|
||||
model_provider=app_orchestration_config_entity.model_config.provider,
|
||||
model_id=app_orchestration_config_entity.model_config.model,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
mode=app_mode,
|
||||
name='New conversation',
|
||||
inputs=application_generate_entity.inputs,
|
||||
introduction=introduction,
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status='normal',
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
else:
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(
|
||||
Conversation.id == application_generate_entity.conversation_id,
|
||||
Conversation.app_id == app_record.id
|
||||
).first()
|
||||
)
|
||||
|
||||
currency = model_schema.pricing.currency if model_schema.pricing else 'USD'
|
||||
|
||||
message = Message(
|
||||
app_id=app_record.id,
|
||||
model_provider=app_orchestration_config_entity.model_config.provider,
|
||||
model_id=app_orchestration_config_entity.model_config.model,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
conversation_id=conversation.id,
|
||||
inputs=application_generate_entity.inputs,
|
||||
query=application_generate_entity.query or "",
|
||||
message="",
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
answer="",
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency=currency,
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
from_account_id=account_id,
|
||||
agent_based=app_orchestration_config_entity.agent is not None
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
|
||||
for file in application_generate_entity.files:
|
||||
message_file = MessageFile(
|
||||
message_id=message.id,
|
||||
type=file.type.value,
|
||||
transfer_method=file.transfer_method.value,
|
||||
url=file.url,
|
||||
upload_file_id=file.upload_file_id,
|
||||
created_by_role=('account' if account_id else 'end_user'),
|
||||
created_by=account_id or end_user_id,
|
||||
)
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
|
||||
return conversation, message
|
||||
|
||||
def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
|
||||
"""
|
||||
Get conversation introduction
|
||||
:param application_generate_entity: application generate entity
|
||||
:return: conversation introduction
|
||||
"""
|
||||
app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
|
||||
introduction = app_orchestration_config_entity.opening_statement
|
||||
|
||||
if introduction:
|
||||
try:
|
||||
inputs = application_generate_entity.inputs
|
||||
prompt_template = PromptTemplateParser(template=introduction)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
introduction = prompt_template.format(prompt_inputs)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return introduction
|
||||
|
||||
def _get_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""
|
||||
Get conversation by conversation id
|
||||
:param conversation_id: conversation id
|
||||
:return: conversation
|
||||
"""
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
return conversation
|
||||
|
||||
def _get_message(self, message_id: str) -> Message:
|
||||
"""
|
||||
Get message by message id
|
||||
:param message_id: message id
|
||||
:return: message
|
||||
"""
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
.filter(Message.id == message_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
return message
|
228
api/core/application_queue_manager.py
Normal file
228
api/core/application_queue_manager.py
Normal file
|
@ -0,0 +1,228 @@
|
|||
import queue
|
||||
import time
|
||||
from typing import Generator, Any
|
||||
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.entities.queue_entities import QueueStopEvent, AppQueueEvent, QueuePingEvent, QueueErrorEvent, \
|
||||
QueueAgentThoughtEvent, QueueMessageEndEvent, QueueRetrieverResourcesEvent, QueueMessageReplaceEvent, \
|
||||
QueueMessageEvent, QueueMessage, AnnotationReplyEvent
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import MessageAgentThought
|
||||
|
||||
|
||||
class ApplicationQueueManager:
|
||||
def __init__(self, task_id: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
conversation_id: str,
|
||||
app_mode: str,
|
||||
message_id: str) -> None:
|
||||
if not user_id:
|
||||
raise ValueError("user is required")
|
||||
|
||||
self._task_id = task_id
|
||||
self._user_id = user_id
|
||||
self._invoke_from = invoke_from
|
||||
self._conversation_id = str(conversation_id)
|
||||
self._app_mode = app_mode
|
||||
self._message_id = str(message_id)
|
||||
|
||||
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
|
||||
redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
|
||||
|
||||
q = queue.Queue()
|
||||
|
||||
self._q = q
|
||||
|
||||
def listen(self) -> Generator:
|
||||
"""
|
||||
Listen to queue
|
||||
:return:
|
||||
"""
|
||||
# wait for 10 minutes to stop listen
|
||||
listen_timeout = 600
|
||||
start_time = time.time()
|
||||
last_ping_time = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = self._q.get(timeout=1)
|
||||
if message is None:
|
||||
break
|
||||
|
||||
yield message
|
||||
except queue.Empty:
|
||||
continue
|
||||
finally:
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time >= listen_timeout or self._is_stopped():
|
||||
# publish two messages to make sure the client can receive the stop signal
|
||||
# and stop listening after the stop signal processed
|
||||
self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
|
||||
self.stop_listen()
|
||||
|
||||
if elapsed_time // 10 > last_ping_time:
|
||||
self.publish(QueuePingEvent())
|
||||
last_ping_time = elapsed_time // 10
|
||||
|
||||
def stop_listen(self) -> None:
|
||||
"""
|
||||
Stop listen to queue
|
||||
:return:
|
||||
"""
|
||||
self._q.put(None)
|
||||
|
||||
def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
|
||||
"""
|
||||
Publish chunk message to channel
|
||||
|
||||
:param chunk: chunk
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueMessageEvent(
|
||||
chunk=chunk
|
||||
))
|
||||
|
||||
def publish_message_replace(self, text: str) -> None:
|
||||
"""
|
||||
Publish message replace
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueMessageReplaceEvent(
|
||||
text=text
|
||||
))
|
||||
|
||||
def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
|
||||
"""
|
||||
Publish retriever resources
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))
|
||||
|
||||
def publish_annotation_reply(self, message_annotation_id: str) -> None:
|
||||
"""
|
||||
Publish annotation reply
|
||||
:param message_annotation_id: message annotation id
|
||||
:return:
|
||||
"""
|
||||
self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))
|
||||
|
||||
def publish_message_end(self, llm_result: LLMResult) -> None:
|
||||
"""
|
||||
Publish message end
|
||||
:param llm_result: llm result
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueMessageEndEvent(llm_result=llm_result))
|
||||
self.stop_listen()
|
||||
|
||||
def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
|
||||
"""
|
||||
Publish agent thought
|
||||
:param message_agent_thought: message agent thought
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=message_agent_thought.id
|
||||
))
|
||||
|
||||
def publish_error(self, e) -> None:
|
||||
"""
|
||||
Publish error
|
||||
:param e: error
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueErrorEvent(
|
||||
error=e
|
||||
))
|
||||
self.stop_listen()
|
||||
|
||||
def publish(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
:return:
|
||||
"""
|
||||
self._check_for_sqlalchemy_models(event.dict())
|
||||
|
||||
message = QueueMessage(
|
||||
task_id=self._task_id,
|
||||
message_id=self._message_id,
|
||||
conversation_id=self._conversation_id,
|
||||
app_mode=self._app_mode,
|
||||
event=event
|
||||
)
|
||||
|
||||
self._q.put(message)
|
||||
|
||||
if isinstance(event, QueueStopEvent):
|
||||
self.stop_listen()
|
||||
|
||||
@classmethod
|
||||
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
|
||||
"""
|
||||
Set task stop flag
|
||||
:return:
|
||||
"""
|
||||
result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
|
||||
if result is None:
|
||||
return
|
||||
|
||||
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
|
||||
if result != f"{user_prefix}-{user_id}":
|
||||
return
|
||||
|
||||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
def _is_stopped(self) -> bool:
|
||||
"""
|
||||
Check if task is stopped
|
||||
:return:
|
||||
"""
|
||||
stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
|
||||
result = redis_client.get(stopped_cache_key)
|
||||
if result is not None:
|
||||
redis_client.delete(stopped_cache_key)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _generate_task_belong_cache_key(cls, task_id: str) -> str:
|
||||
"""
|
||||
Generate task belong cache key
|
||||
:param task_id: task id
|
||||
:return:
|
||||
"""
|
||||
return f"generate_task_belong:{task_id}"
|
||||
|
||||
@classmethod
|
||||
def _generate_stopped_cache_key(cls, task_id: str) -> str:
|
||||
"""
|
||||
Generate stopped cache key
|
||||
:param task_id: task id
|
||||
:return:
|
||||
"""
|
||||
return f"generate_task_stopped:{task_id}"
|
||||
|
||||
def _check_for_sqlalchemy_models(self, data: Any):
|
||||
# from entity to dict or list
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
self._check_for_sqlalchemy_models(value)
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
self._check_for_sqlalchemy_models(item)
|
||||
else:
|
||||
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
|
||||
raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
|
||||
"that cause thread safety issues is not allowed.")
|
||||
|
||||
|
||||
class ConversationTaskStoppedException(Exception):
|
||||
pass
|
|
@ -2,30 +2,40 @@ import json
|
|||
import logging
|
||||
import time
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
from typing import Any, Dict, List, Union, Optional, cast
|
||||
|
||||
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
|
||||
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
|
||||
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessage
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from extensions.ext_database import db
|
||||
from models.model import MessageChain, MessageAgentThought, Message
|
||||
|
||||
|
||||
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
|
||||
def __init__(self, model_config: ModelConfigEntity,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
message: Message,
|
||||
message_chain: MessageChain) -> None:
|
||||
"""Initialize callback handler."""
|
||||
self.model_instance = model_instance
|
||||
self.conversation_message_task = conversation_message_task
|
||||
self.model_config = model_config
|
||||
self.queue_manager = queue_manager
|
||||
self.message = message
|
||||
self.message_chain = message_chain
|
||||
model_type_instance = self.model_config.provider_model_bundle.model_type_instance
|
||||
self.model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
self.current_chain = None
|
||||
|
||||
@property
|
||||
def agent_loops(self) -> List[AgentLoop]:
|
||||
|
@ -46,65 +56,60 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
|||
"""Whether to ignore chain callbacks."""
|
||||
return True
|
||||
|
||||
def on_llm_before_invoke(self, prompt_messages: list[PromptMessage]) -> None:
|
||||
if not self._current_loop:
|
||||
# Agent start with a LLM query
|
||||
self._current_loop = AgentLoop(
|
||||
position=len(self._agent_loops) + 1,
|
||||
prompt="\n".join([prompt_message.content for prompt_message in prompt_messages]),
|
||||
status='llm_started',
|
||||
started_at=time.perf_counter()
|
||||
)
|
||||
|
||||
def on_llm_after_invoke(self, result: RuntimeLLMResult) -> None:
|
||||
if self._current_loop and self._current_loop.status == 'llm_started':
|
||||
self._current_loop.status = 'llm_end'
|
||||
if result.usage:
|
||||
self._current_loop.prompt_tokens = result.usage.prompt_tokens
|
||||
else:
|
||||
self._current_loop.prompt_tokens = self.model_type_instance.get_num_tokens(
|
||||
model=self.model_config.model,
|
||||
credentials=self.model_config.credentials,
|
||||
prompt_messages=[UserPromptMessage(content=self._current_loop.prompt)]
|
||||
)
|
||||
|
||||
completion_message = result.message
|
||||
if completion_message.tool_calls:
|
||||
self._current_loop.completion \
|
||||
= json.dumps({'function_call': completion_message.tool_calls})
|
||||
else:
|
||||
self._current_loop.completion = completion_message.content
|
||||
|
||||
if result.usage:
|
||||
self._current_loop.completion_tokens = result.usage.completion_tokens
|
||||
else:
|
||||
self._current_loop.completion_tokens = self.model_type_instance.get_num_tokens(
|
||||
model=self.model_config.model,
|
||||
credentials=self.model_config.credentials,
|
||||
prompt_messages=[AssistantPromptMessage(content=self._current_loop.completion)]
|
||||
)
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
if not self._current_loop:
|
||||
# Agent start with a LLM query
|
||||
self._current_loop = AgentLoop(
|
||||
position=len(self._agent_loops) + 1,
|
||||
prompt="\n".join([message.content for message in messages[0]]),
|
||||
status='llm_started',
|
||||
started_at=time.perf_counter()
|
||||
)
|
||||
pass
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out the prompts."""
|
||||
# serialized={'name': 'OpenAI'}
|
||||
# prompts=['Answer the following questions...\nThought:']
|
||||
# kwargs={}
|
||||
if not self._current_loop:
|
||||
# Agent start with a LLM query
|
||||
self._current_loop = AgentLoop(
|
||||
position=len(self._agent_loops) + 1,
|
||||
prompt=prompts[0],
|
||||
status='llm_started',
|
||||
started_at=time.perf_counter()
|
||||
)
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
# kwargs={}
|
||||
if self._current_loop and self._current_loop.status == 'llm_started':
|
||||
self._current_loop.status = 'llm_end'
|
||||
if response.llm_output:
|
||||
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
else:
|
||||
self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self._current_loop.prompt)]
|
||||
)
|
||||
completion_generation = response.generations[0][0]
|
||||
if isinstance(completion_generation, ChatGeneration):
|
||||
completion_message = completion_generation.message
|
||||
if 'function_call' in completion_message.additional_kwargs:
|
||||
self._current_loop.completion \
|
||||
= json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
|
||||
else:
|
||||
self._current_loop.completion = response.generations[0][0].text
|
||||
else:
|
||||
self._current_loop.completion = completion_generation.text
|
||||
|
||||
if response.llm_output:
|
||||
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self._current_loop.completion)]
|
||||
)
|
||||
pass
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
|
@ -150,10 +155,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
|||
if completion is not None:
|
||||
self._current_loop.completion = completion
|
||||
|
||||
self._message_agent_thought = self.conversation_message_task.on_agent_start(
|
||||
self.current_chain,
|
||||
self._current_loop
|
||||
)
|
||||
self._message_agent_thought = self._init_agent_thought()
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
|
@ -176,9 +178,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
|||
self._current_loop.completed_at = time.perf_counter()
|
||||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||
|
||||
self.conversation_message_task.on_agent_end(
|
||||
self._message_agent_thought, self.model_instance, self._current_loop
|
||||
)
|
||||
self._complete_agent_thought(self._message_agent_thought)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
self._current_loop = None
|
||||
|
@ -202,17 +202,62 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
|||
self._current_loop.completed_at = time.perf_counter()
|
||||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||
self._current_loop.thought = '[DONE]'
|
||||
self._message_agent_thought = self.conversation_message_task.on_agent_start(
|
||||
self.current_chain,
|
||||
self._current_loop
|
||||
)
|
||||
self._message_agent_thought = self._init_agent_thought()
|
||||
|
||||
self.conversation_message_task.on_agent_end(
|
||||
self._message_agent_thought, self.model_instance, self._current_loop
|
||||
)
|
||||
self._complete_agent_thought(self._message_agent_thought)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
elif not self._current_loop and self._agent_loops:
|
||||
self._agent_loops[-1].status = 'agent_finish'
|
||||
|
||||
def _init_agent_thought(self) -> MessageAgentThought:
|
||||
message_agent_thought = MessageAgentThought(
|
||||
message_id=self.message.id,
|
||||
message_chain_id=self.message_chain.id,
|
||||
position=self._current_loop.position,
|
||||
thought=self._current_loop.thought,
|
||||
tool=self._current_loop.tool_name,
|
||||
tool_input=self._current_loop.tool_input,
|
||||
message=self._current_loop.prompt,
|
||||
message_price_unit=0,
|
||||
answer=self._current_loop.completion,
|
||||
answer_price_unit=0,
|
||||
created_by_role=('account' if self.message.from_source == 'console' else 'end_user'),
|
||||
created_by=(self.message.from_account_id
|
||||
if self.message.from_source == 'console' else self.message.from_end_user_id)
|
||||
)
|
||||
|
||||
db.session.add(message_agent_thought)
|
||||
db.session.commit()
|
||||
|
||||
self.queue_manager.publish_agent_thought(message_agent_thought)
|
||||
|
||||
return message_agent_thought
|
||||
|
||||
def _complete_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
|
||||
loop_message_tokens = self._current_loop.prompt_tokens
|
||||
loop_answer_tokens = self._current_loop.completion_tokens
|
||||
|
||||
# transform usage
|
||||
llm_usage = self.model_type_instance._calc_response_usage(
|
||||
self.model_config.model,
|
||||
self.model_config.credentials,
|
||||
loop_message_tokens,
|
||||
loop_answer_tokens
|
||||
)
|
||||
|
||||
message_agent_thought.observation = self._current_loop.tool_output
|
||||
message_agent_thought.tool_process_data = '' # currently not support
|
||||
message_agent_thought.message_token = loop_message_tokens
|
||||
message_agent_thought.message_unit_price = llm_usage.prompt_unit_price
|
||||
message_agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
message_agent_thought.answer_token = loop_answer_tokens
|
||||
message_agent_thought.answer_unit_price = llm_usage.completion_unit_price
|
||||
message_agent_thought.answer_price_unit = llm_usage.completion_price_unit
|
||||
message_agent_thought.latency = self._current_loop.latency
|
||||
message_agent_thought.tokens = self._current_loop.prompt_tokens + self._current_loop.completion_tokens
|
||||
message_agent_thought.total_price = llm_usage.total_price
|
||||
message_agent_thought.currency = llm_usage.currency
|
||||
db.session.commit()
|
||||
|
|
|
@ -1,74 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
||||
from core.callback_handler.entity.dataset_query import DatasetQueryObj
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
|
||||
|
||||
class DatasetToolCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
|
||||
"""Initialize callback handler."""
|
||||
self.queries = []
|
||||
self.conversation_message_task = conversation_message_task
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return False
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
tool_name: str = serialized.get('name')
|
||||
dataset_id = tool_name.removeprefix('dataset-')
|
||||
|
||||
try:
|
||||
input_dict = json.loads(input_str.replace("'", "\""))
|
||||
query = input_dict.get('query')
|
||||
except JSONDecodeError:
|
||||
query = input_str
|
||||
|
||||
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
logging.debug("Dataset tool on_llm_error: %s", error)
|
|
@ -1,16 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ChainResult(BaseModel):
|
||||
type: str = None
|
||||
prompt: dict = None
|
||||
completion: dict = None
|
||||
|
||||
status: str = 'chain_started'
|
||||
completed: bool = False
|
||||
|
||||
started_at: float = None
|
||||
completed_at: float = None
|
||||
|
||||
agent_result: dict = None
|
||||
"""only when type is 'AgentExecutor'"""
|
|
@ -1,6 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DatasetQueryObj(BaseModel):
|
||||
dataset_id: str = None
|
||||
query: str = None
|
|
@ -1,8 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMMessage(BaseModel):
|
||||
prompt: str = ''
|
||||
prompt_tokens: int = 0
|
||||
completion: str = ''
|
||||
completion_tokens: int = 0
|
|
@ -1,17 +1,44 @@
|
|||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment
|
||||
from models.dataset import DocumentSegment, DatasetQuery
|
||||
from models.model import DatasetRetrieverResource
|
||||
|
||||
|
||||
class DatasetIndexToolCallbackHandler:
|
||||
"""Callback handler for dataset tool."""
|
||||
|
||||
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
|
||||
self.conversation_message_task = conversation_message_task
|
||||
def __init__(self, queue_manager: ApplicationQueueManager,
|
||||
app_id: str,
|
||||
message_id: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom) -> None:
|
||||
self._queue_manager = queue_manager
|
||||
self._app_id = app_id
|
||||
self._message_id = message_id
|
||||
self._user_id = user_id
|
||||
self._invoke_from = invoke_from
|
||||
|
||||
def on_query(self, query: str, dataset_id: str) -> None:
|
||||
"""
|
||||
Handle query.
|
||||
"""
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_id,
|
||||
content=query,
|
||||
source='app',
|
||||
source_app_id=self._app_id,
|
||||
created_by_role=('account'
|
||||
if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
|
||||
created_by=self._user_id
|
||||
)
|
||||
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
|
||||
def on_tool_end(self, documents: List[Document]) -> None:
|
||||
"""Handle tool end."""
|
||||
|
@ -30,4 +57,27 @@ class DatasetIndexToolCallbackHandler:
|
|||
|
||||
def return_retriever_resource_info(self, resource: List):
|
||||
"""Handle return_retriever_resource_info."""
|
||||
self.conversation_message_task.on_dataset_query_finish(resource)
|
||||
if resource and len(resource) > 0:
|
||||
for item in resource:
|
||||
dataset_retriever_resource = DatasetRetrieverResource(
|
||||
message_id=self._message_id,
|
||||
position=item.get('position'),
|
||||
dataset_id=item.get('dataset_id'),
|
||||
dataset_name=item.get('dataset_name'),
|
||||
document_id=item.get('document_id'),
|
||||
document_name=item.get('document_name'),
|
||||
data_source_type=item.get('data_source_type'),
|
||||
segment_id=item.get('segment_id'),
|
||||
score=item.get('score') if 'score' in item else None,
|
||||
hit_count=item.get('hit_count') if 'hit_count' else None,
|
||||
word_count=item.get('word_count') if 'word_count' in item else None,
|
||||
segment_position=item.get('segment_position') if 'segment_position' in item else None,
|
||||
index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
|
||||
content=item.get('content'),
|
||||
retriever_from=item.get('retriever_from'),
|
||||
created_by=self._user_id
|
||||
)
|
||||
db.session.add(dataset_retriever_resource)
|
||||
db.session.commit()
|
||||
|
||||
self._queue_manager.publish_retriever_resources(resource)
|
||||
|
|
|
@ -1,284 +0,0 @@
|
|||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import LLMResult, BaseMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.callback_handler.entity.llm_message import LLMMessage
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
|
||||
ConversationTaskInterruptException
|
||||
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
|
||||
ImagePromptMessageFile
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.moderation.base import ModerationOutputsResult, ModerationAction
|
||||
from core.moderation.factory import ModerationFactory
|
||||
|
||||
|
||||
class ModerationRule(BaseModel):
|
||||
type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class LLMCallbackHandler(BaseCallbackHandler):
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, model_instance: BaseLLM,
|
||||
conversation_message_task: ConversationMessageTask):
|
||||
self.model_instance = model_instance
|
||||
self.llm_message = LLMMessage()
|
||||
self.start_at = None
|
||||
self.conversation_message_task = conversation_message_task
|
||||
|
||||
self.output_moderation_handler = None
|
||||
self.init_output_moderation()
|
||||
|
||||
def init_output_moderation(self):
|
||||
app_model_config = self.conversation_message_task.app_model_config
|
||||
sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
|
||||
|
||||
if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
|
||||
self.output_moderation_handler = OutputModerationHandler(
|
||||
tenant_id=self.conversation_message_task.tenant_id,
|
||||
app_id=self.conversation_message_task.app.id,
|
||||
rule=ModerationRule(
|
||||
type=sensitive_word_avoidance_dict.get("type"),
|
||||
config=sensitive_word_avoidance_dict.get("config")
|
||||
),
|
||||
on_message_replace_func=self.conversation_message_task.on_message_replace
|
||||
)
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
real_prompts = []
|
||||
for message in messages[0]:
|
||||
if message.type == 'human':
|
||||
role = 'user'
|
||||
elif message.type == 'ai':
|
||||
role = 'assistant'
|
||||
else:
|
||||
role = 'system'
|
||||
|
||||
real_prompts.append({
|
||||
"role": role,
|
||||
"text": message.content,
|
||||
"files": [{
|
||||
"type": file.type.value,
|
||||
"data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
|
||||
"detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
|
||||
} for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
|
||||
})
|
||||
|
||||
self.llm_message.prompt = real_prompts
|
||||
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
self.llm_message.prompt = [{
|
||||
"role": 'user',
|
||||
"text": prompts[0]
|
||||
}]
|
||||
|
||||
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
if self.output_moderation_handler:
|
||||
self.output_moderation_handler.stop_thread()
|
||||
|
||||
self.llm_message.completion = self.output_moderation_handler.moderation_completion(
|
||||
completion=response.generations[0][0].text,
|
||||
public_event=True if self.conversation_message_task.streaming else False
|
||||
)
|
||||
else:
|
||||
self.llm_message.completion = response.generations[0][0].text
|
||||
|
||||
if not self.conversation_message_task.streaming:
|
||||
self.conversation_message_task.append_message_text(self.llm_message.completion)
|
||||
|
||||
if response.llm_output and 'token_usage' in response.llm_output:
|
||||
if 'prompt_tokens' in response.llm_output['token_usage']:
|
||||
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
|
||||
if 'completion_tokens' in response.llm_output['token_usage']:
|
||||
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)])
|
||||
else:
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)])
|
||||
|
||||
self.conversation_message_task.save_message(self.llm_message)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
|
||||
# stop subscribe new token when output moderation should direct output
|
||||
ex = ConversationTaskInterruptException()
|
||||
self.on_llm_error(error=ex)
|
||||
raise ex
|
||||
|
||||
try:
|
||||
self.conversation_message_task.append_message_text(token)
|
||||
self.llm_message.completion += token
|
||||
|
||||
if self.output_moderation_handler:
|
||||
self.output_moderation_handler.append_new_token(token)
|
||||
except ConversationTaskStoppedException as ex:
|
||||
self.on_llm_error(error=ex)
|
||||
raise ex
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
if self.output_moderation_handler:
|
||||
self.output_moderation_handler.stop_thread()
|
||||
|
||||
if isinstance(error, ConversationTaskStoppedException):
|
||||
if self.conversation_message_task.streaming:
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)]
|
||||
)
|
||||
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
|
||||
if isinstance(error, ConversationTaskInterruptException):
|
||||
self.llm_message.completion = self.output_moderation_handler.get_final_output()
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)]
|
||||
)
|
||||
self.conversation_message_task.save_message(llm_message=self.llm_message)
|
||||
else:
|
||||
logging.debug("on_llm_error: %s", error)
|
||||
|
||||
|
||||
class OutputModerationHandler(BaseModel):
|
||||
DEFAULT_BUFFER_SIZE: int = 300
|
||||
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
|
||||
rule: ModerationRule
|
||||
on_message_replace_func: Any
|
||||
|
||||
thread: Optional[threading.Thread] = None
|
||||
thread_running: bool = True
|
||||
buffer: str = ''
|
||||
is_final_chunk: bool = False
|
||||
final_output: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def should_direct_output(self):
|
||||
return self.final_output is not None
|
||||
|
||||
def get_final_output(self):
|
||||
return self.final_output
|
||||
|
||||
def append_new_token(self, token: str):
|
||||
self.buffer += token
|
||||
|
||||
if not self.thread:
|
||||
self.thread = self.start_thread()
|
||||
|
||||
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
|
||||
self.buffer = completion
|
||||
self.is_final_chunk = True
|
||||
|
||||
result = self.moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
moderation_buffer=completion
|
||||
)
|
||||
|
||||
if not result or not result.flagged:
|
||||
return completion
|
||||
|
||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
final_output = result.preset_response
|
||||
else:
|
||||
final_output = result.text
|
||||
|
||||
if public_event:
|
||||
self.on_message_replace_func(final_output)
|
||||
|
||||
return final_output
|
||||
|
||||
def start_thread(self) -> threading.Thread:
|
||||
buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
|
||||
thread = threading.Thread(target=self.worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
|
||||
})
|
||||
|
||||
thread.start()
|
||||
|
||||
return thread
|
||||
|
||||
def stop_thread(self):
|
||||
if self.thread and self.thread.is_alive():
|
||||
self.thread_running = False
|
||||
|
||||
def worker(self, flask_app: Flask, buffer_size: int):
|
||||
with flask_app.app_context():
|
||||
current_length = 0
|
||||
while self.thread_running:
|
||||
moderation_buffer = self.buffer
|
||||
buffer_length = len(moderation_buffer)
|
||||
if not self.is_final_chunk:
|
||||
chunk_length = buffer_length - current_length
|
||||
if 0 <= chunk_length < buffer_size:
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
current_length = buffer_length
|
||||
|
||||
result = self.moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
moderation_buffer=moderation_buffer
|
||||
)
|
||||
|
||||
if not result or not result.flagged:
|
||||
continue
|
||||
|
||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
final_output = result.preset_response
|
||||
self.final_output = final_output
|
||||
else:
|
||||
final_output = result.text + self.buffer[len(moderation_buffer):]
|
||||
|
||||
# trigger replace event
|
||||
if self.thread_running:
|
||||
self.on_message_replace_func(final_output)
|
||||
|
||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
break
|
||||
|
||||
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
|
||||
try:
|
||||
moderation_factory = ModerationFactory(
|
||||
name=self.rule.type,
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
config=self.rule.config
|
||||
)
|
||||
|
||||
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error("Moderation Output error: %s", e)
|
||||
|
||||
return None
|
|
@ -1,76 +0,0 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
||||
from core.callback_handler.entity.chain_result import ChainResult
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
|
||||
|
||||
class MainChainGatherCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
|
||||
"""Initialize callback handler."""
|
||||
self._current_chain_result = None
|
||||
self._current_chain_message = None
|
||||
self.conversation_message_task = conversation_message_task
|
||||
self.agent_callback = None
|
||||
|
||||
def clear_chain_results(self) -> None:
|
||||
self._current_chain_result = None
|
||||
self._current_chain_message = None
|
||||
if self.agent_callback:
|
||||
self.agent_callback.current_chain = None
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return True
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain."""
|
||||
if not self._current_chain_result:
|
||||
chain_type = serialized['id'][-1]
|
||||
if chain_type:
|
||||
self._current_chain_result = ChainResult(
|
||||
type=chain_type,
|
||||
prompt=inputs,
|
||||
started_at=time.perf_counter()
|
||||
)
|
||||
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
|
||||
if self.agent_callback:
|
||||
self.agent_callback.current_chain = self._current_chain_message
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
if self._current_chain_result and self._current_chain_result.status == 'chain_started':
|
||||
self._current_chain_result.status = 'chain_ended'
|
||||
self._current_chain_result.completion = outputs
|
||||
self._current_chain_result.completed = True
|
||||
self._current_chain_result.completed_at = time.perf_counter()
|
||||
|
||||
self.conversation_message_task.on_chain_end(self._current_chain_message, self._current_chain_result)
|
||||
|
||||
self.clear_chain_results()
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
logging.debug("Dataset tool on_chain_error: %s", error)
|
||||
self.clear_chain_results()
|
|
@ -79,8 +79,11 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
|
|||
"""Run on agent action."""
|
||||
tool = action.tool
|
||||
tool_input = action.tool_input
|
||||
try:
|
||||
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
|
||||
thought = action.log[:action_name_position].strip() if action.log else ''
|
||||
except ValueError:
|
||||
thought = ''
|
||||
|
||||
log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
|
||||
print_text("\n[on_agent_action]\n" + log + "\n", color='green')
|
||||
|
|
|
@ -5,15 +5,19 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
|
|||
from langchain.schema import LLMResult, Generation
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.agent.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.model_manager import ModelInstance
|
||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
|
||||
|
||||
class LLMChain(LCLLMChain):
|
||||
model_instance: BaseLLM
|
||||
model_config: ModelConfigEntity
|
||||
"""The language model instance to use."""
|
||||
llm: BaseLanguageModel = FakeLLM(response="")
|
||||
parameters: Dict[str, Any] = {}
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None
|
||||
|
||||
def generate(
|
||||
self,
|
||||
|
@ -23,14 +27,23 @@ class LLMChain(LCLLMChain):
|
|||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
|
||||
messages = prompts[0].to_messages()
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
stop=stop
|
||||
prompt_messages = lc_messages_to_prompt_messages(messages)
|
||||
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=self.model_config.provider_model_bundle,
|
||||
model=self.model_config.model,
|
||||
)
|
||||
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
stream=False,
|
||||
stop=stop,
|
||||
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None,
|
||||
model_parameters=self.parameters
|
||||
)
|
||||
|
||||
generations = [
|
||||
[Generation(text=result.content)]
|
||||
[Generation(text=result.message.content)]
|
||||
]
|
||||
|
||||
return LLMResult(generations=generations)
|
||||
|
|
|
@ -1,501 +0,0 @@
|
|||
import concurrent
|
||||
import json
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional, List, Union, Tuple
|
||||
|
||||
from flask import current_app, Flask
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
|
||||
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
|
||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
|
||||
ConversationTaskInterruptException
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from core.file.file_obj import FileObj
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.orchestrator_rule_parser import OrchestratorRuleParser
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from models.dataset import Dataset
|
||||
from models.model import App, AppModelConfig, Account, Conversation, EndUser
|
||||
from core.moderation.base import ModerationException, ModerationAction
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
|
||||
class Completion:
|
||||
@classmethod
|
||||
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
|
||||
files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
|
||||
streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
|
||||
auto_generate_name: bool = True, from_source: str = 'console'):
|
||||
"""
|
||||
errors: ProviderTokenNotInitError
|
||||
"""
|
||||
query = PromptTemplateParser.remove_template_variables(query)
|
||||
|
||||
memory = None
|
||||
if conversation:
|
||||
# get memory of conversation (read-only)
|
||||
memory = cls.get_memory_from_conversation(
|
||||
tenant_id=app.tenant_id,
|
||||
app_model_config=app_model_config,
|
||||
conversation=conversation,
|
||||
return_messages=False
|
||||
)
|
||||
|
||||
inputs = conversation.inputs
|
||||
|
||||
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
|
||||
tenant_id=app.tenant_id,
|
||||
model_config=app_model_config.model_dict,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
conversation_message_task = ConversationMessageTask(
|
||||
task_id=task_id,
|
||||
app=app,
|
||||
app_model_config=app_model_config,
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
is_override=is_override,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
streaming=streaming,
|
||||
model_instance=final_model_instance,
|
||||
auto_generate_name=auto_generate_name
|
||||
)
|
||||
|
||||
prompt_message_files = [file.prompt_message_file for file in files]
|
||||
|
||||
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
|
||||
mode=app.mode,
|
||||
model_instance=final_model_instance,
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
files=prompt_message_files
|
||||
)
|
||||
|
||||
# init orchestrator rule parser
|
||||
orchestrator_rule_parser = OrchestratorRuleParser(
|
||||
tenant_id=app.tenant_id,
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
|
||||
try:
|
||||
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
|
||||
|
||||
try:
|
||||
# process sensitive_word_avoidance
|
||||
inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
|
||||
except ModerationException as e:
|
||||
cls.run_final_llm(
|
||||
model_instance=final_model_instance,
|
||||
mode=app.mode,
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
files=prompt_message_files,
|
||||
agent_execute_result=None,
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
fake_response=str(e)
|
||||
)
|
||||
return
|
||||
# check annotation reply
|
||||
annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source)
|
||||
if annotation_reply:
|
||||
return
|
||||
# fill in variable inputs from external data tools if exists
|
||||
external_data_tools = app_model_config.external_data_tools_list
|
||||
if external_data_tools:
|
||||
inputs = cls.fill_in_inputs_from_external_data_tools(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
external_data_tools=external_data_tools,
|
||||
inputs=inputs,
|
||||
query=query
|
||||
)
|
||||
|
||||
# get agent executor
|
||||
agent_executor = orchestrator_rule_parser.to_agent_executor(
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
rest_tokens=rest_tokens_for_context_and_memory,
|
||||
chain_callback=chain_callback,
|
||||
tenant_id=app.tenant_id,
|
||||
retriever_from=retriever_from
|
||||
)
|
||||
|
||||
query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)
|
||||
|
||||
# run agent executor
|
||||
agent_execute_result = None
|
||||
if query_for_agent and agent_executor:
|
||||
should_use_agent = agent_executor.should_use_agent(query_for_agent)
|
||||
if should_use_agent:
|
||||
agent_execute_result = agent_executor.run(query_for_agent)
|
||||
|
||||
# When no extra pre prompt is specified,
|
||||
# the output of the agent can be used directly as the main output content without calling LLM again
|
||||
fake_response = None
|
||||
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
|
||||
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
|
||||
PlanningStrategy.REACT_ROUTER]:
|
||||
fake_response = agent_execute_result.output
|
||||
|
||||
# run the final llm
|
||||
cls.run_final_llm(
|
||||
model_instance=final_model_instance,
|
||||
mode=app.mode,
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
files=prompt_message_files,
|
||||
agent_execute_result=agent_execute_result,
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
fake_response=fake_response
|
||||
)
|
||||
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
|
||||
return
|
||||
except ChunkedEncodingError as e:
|
||||
# Interrupt by LLM (like OpenAI), handle it.
|
||||
logging.warning(f'ChunkedEncodingError: {e}')
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
|
||||
query: str):
|
||||
if not app_model_config.sensitive_word_avoidance_dict['enabled']:
|
||||
return inputs, query
|
||||
|
||||
type = app_model_config.sensitive_word_avoidance_dict['type']
|
||||
|
||||
moderation = ModerationFactory(type, app_id, tenant_id,
|
||||
app_model_config.sensitive_word_avoidance_dict['config'])
|
||||
moderation_result = moderation.moderation_for_inputs(inputs, query)
|
||||
|
||||
if not moderation_result.flagged:
|
||||
return inputs, query
|
||||
|
||||
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
raise ModerationException(moderation_result.preset_response)
|
||||
elif moderation_result.action == ModerationAction.OVERRIDED:
|
||||
inputs = moderation_result.inputs
|
||||
query = moderation_result.query
|
||||
|
||||
return inputs, query
|
||||
|
||||
@classmethod
|
||||
def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
|
||||
inputs: dict, query: str) -> dict:
|
||||
"""
|
||||
Fill in variable inputs from external data tools if exists.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param app_id: app id
|
||||
:param external_data_tools: external data tools configs
|
||||
:param inputs: the inputs
|
||||
:param query: the query
|
||||
:return: the filled inputs
|
||||
"""
|
||||
# Group tools by type and config
|
||||
grouped_tools = {}
|
||||
for tool in external_data_tools:
|
||||
if not tool.get("enabled"):
|
||||
continue
|
||||
|
||||
tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
|
||||
grouped_tools.setdefault(tool_key, []).append(tool)
|
||||
|
||||
results = {}
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = {}
|
||||
for tool in external_data_tools:
|
||||
if not tool.get("enabled"):
|
||||
continue
|
||||
|
||||
future = executor.submit(
|
||||
cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool,
|
||||
inputs, query
|
||||
)
|
||||
|
||||
futures[future] = tool
|
||||
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
tool_variable, result = future.result()
|
||||
results[tool_variable] = result
|
||||
|
||||
inputs.update(results)
|
||||
return inputs
|
||||
|
||||
@classmethod
|
||||
def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
|
||||
inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
with flask_app.app_context():
|
||||
tool_variable = external_data_tool.get("variable")
|
||||
tool_type = external_data_tool.get("type")
|
||||
tool_config = external_data_tool.get("config")
|
||||
|
||||
external_data_tool_factory = ExternalDataToolFactory(
|
||||
name=tool_type,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
variable=tool_variable,
|
||||
config=tool_config
|
||||
)
|
||||
|
||||
# query external data tool
|
||||
result = external_data_tool_factory.query(
|
||||
inputs=inputs,
|
||||
query=query
|
||||
)
|
||||
|
||||
return tool_variable, result
|
||||
|
||||
@classmethod
|
||||
def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
|
||||
if app.mode != 'completion':
|
||||
return query
|
||||
|
||||
return inputs.get(app_model_config.dataset_query_variable, "")
|
||||
|
||||
@classmethod
|
||||
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
|
||||
inputs: dict,
|
||||
files: List[PromptMessageFile],
|
||||
agent_execute_result: Optional[AgentExecuteResult],
|
||||
conversation_message_task: ConversationMessageTask,
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
|
||||
fake_response: Optional[str]):
|
||||
prompt_transform = PromptTransform()
|
||||
|
||||
# get llm prompt
|
||||
if app_model_config.prompt_type == 'simple':
|
||||
prompt_messages, stop_words = prompt_transform.get_prompt(
|
||||
app_mode=mode,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=agent_execute_result.output if agent_execute_result else None,
|
||||
memory=memory,
|
||||
model_instance=model_instance
|
||||
)
|
||||
else:
|
||||
prompt_messages = prompt_transform.get_advanced_prompt(
|
||||
app_mode=mode,
|
||||
app_model_config=app_model_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=agent_execute_result.output if agent_execute_result else None,
|
||||
memory=memory,
|
||||
model_instance=model_instance
|
||||
)
|
||||
|
||||
model_config = app_model_config.model_dict
|
||||
completion_params = model_config.get("completion_params", {})
|
||||
stop_words = completion_params.get("stop", [])
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
response = model_instance.run(
|
||||
messages=prompt_messages,
|
||||
stop=stop_words if stop_words else None,
|
||||
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
|
||||
fake_response=fake_response
|
||||
)
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
|
||||
max_token_limit: int) -> str:
|
||||
"""Get memory messages."""
|
||||
memory.max_token_limit = max_token_limit
|
||||
memory_key = memory.memory_variables[0]
|
||||
external_context = memory.load_memory_variables({})
|
||||
return external_context[memory_key]
|
||||
|
||||
@classmethod
|
||||
def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask,
|
||||
from_source: str) -> bool:
|
||||
"""Get memory messages."""
|
||||
app_model_config = conversation_message_task.app_model_config
|
||||
app = conversation_message_task.app
|
||||
annotation_reply = app_model_config.annotation_reply_dict
|
||||
if annotation_reply['enabled']:
|
||||
try:
|
||||
score_threshold = annotation_reply.get('score_threshold', 1)
|
||||
embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name']
|
||||
embedding_model_name = annotation_reply['embedding_model']['embedding_model_name']
|
||||
# get embedding model
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=app.tenant_id,
|
||||
model_provider_name=embedding_provider_name,
|
||||
model_name=embedding_model_name
|
||||
)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_provider_name,
|
||||
embedding_model_name,
|
||||
'annotation'
|
||||
)
|
||||
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
embedding_model_provider=embedding_provider_name,
|
||||
embedding_model=embedding_model_name,
|
||||
collection_binding_id=dataset_collection_binding.id
|
||||
)
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings,
|
||||
attributes=['doc_id', 'annotation_id', 'app_id']
|
||||
)
|
||||
|
||||
documents = vector_index.search(
|
||||
conversation_message_task.query,
|
||||
search_type='similarity_score_threshold',
|
||||
search_kwargs={
|
||||
'k': 1,
|
||||
'score_threshold': score_threshold,
|
||||
'filter': {
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
}
|
||||
)
|
||||
if documents:
|
||||
annotation_id = documents[0].metadata['annotation_id']
|
||||
score = documents[0].metadata['score']
|
||||
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
|
||||
if annotation:
|
||||
conversation_message_task.annotation_end(annotation.content, annotation.id, annotation.account.name)
|
||||
# insert annotation history
|
||||
AppAnnotationService.add_annotation_history(annotation.id,
|
||||
app.id,
|
||||
annotation.question,
|
||||
annotation.content,
|
||||
conversation_message_task.query,
|
||||
conversation_message_task.user.id,
|
||||
conversation_message_task.message.id,
|
||||
from_source,
|
||||
score)
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.warning(f'Query annotation failed, exception: {str(e)}.')
|
||||
return False
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
|
||||
conversation: Conversation,
|
||||
**kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
|
||||
# only for calc token in memory
|
||||
memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
|
||||
tenant_id=tenant_id,
|
||||
model_config=app_model_config.model_dict
|
||||
)
|
||||
|
||||
# use llm config from conversation
|
||||
memory = ReadOnlyConversationTokenDBBufferSharedMemory(
|
||||
conversation=conversation,
|
||||
model_instance=memory_model_instance,
|
||||
max_token_limit=kwargs.get("max_token_limit", 2048),
|
||||
memory_key=kwargs.get("memory_key", "chat_history"),
|
||||
return_messages=kwargs.get("return_messages", True),
|
||||
input_key=kwargs.get("input_key", "input"),
|
||||
output_key=kwargs.get("output_key", "output"),
|
||||
message_limit=kwargs.get("message_limit", 10),
|
||||
)
|
||||
|
||||
return memory
|
||||
|
||||
@classmethod
|
||||
def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
|
||||
query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
|
||||
model_limited_tokens = model_instance.model_rules.max_tokens.max
|
||||
max_tokens = model_instance.get_model_kwargs().max_tokens
|
||||
|
||||
if model_limited_tokens is None:
|
||||
return -1
|
||||
|
||||
if max_tokens is None:
|
||||
max_tokens = 0
|
||||
|
||||
prompt_transform = PromptTransform()
|
||||
|
||||
# get prompt without memory and context
|
||||
if app_model_config.prompt_type == 'simple':
|
||||
prompt_messages, _ = prompt_transform.get_prompt(
|
||||
app_mode=mode,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=None,
|
||||
memory=None,
|
||||
model_instance=model_instance
|
||||
)
|
||||
else:
|
||||
prompt_messages = prompt_transform.get_advanced_prompt(
|
||||
app_mode=mode,
|
||||
app_model_config=app_model_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=None,
|
||||
memory=None,
|
||||
model_instance=model_instance
|
||||
)
|
||||
|
||||
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
|
||||
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
|
||||
if rest_tokens < 0:
|
||||
raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
|
||||
"or shrink the max token, or switch to a llm with a larger token limit size.")
|
||||
|
||||
return rest_tokens
|
||||
|
||||
@classmethod
|
||||
def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
|
||||
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
model_limited_tokens = model_instance.model_rules.max_tokens.max
|
||||
max_tokens = model_instance.get_model_kwargs().max_tokens
|
||||
|
||||
if model_limited_tokens is None:
|
||||
return
|
||||
|
||||
if max_tokens is None:
|
||||
max_tokens = 0
|
||||
|
||||
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
|
||||
|
||||
if prompt_tokens + max_tokens > model_limited_tokens:
|
||||
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
|
||||
|
||||
# update model instance max tokens
|
||||
model_kwargs = model_instance.get_model_kwargs()
|
||||
model_kwargs.max_tokens = max_tokens
|
||||
model_instance.set_model_kwargs(model_kwargs)
|
|
@ -1,517 +0,0 @@
|
|||
import json
|
||||
import time
|
||||
from typing import Optional, Union, List
|
||||
|
||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
from core.callback_handler.entity.dataset_query import DatasetQueryObj
|
||||
from core.callback_handler.entity.llm_message import LLMMessage
|
||||
from core.callback_handler.entity.chain_result import ChainResult
|
||||
from core.file.file_obj import FileObj
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import DatasetQuery
|
||||
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
|
||||
MessageChain, DatasetRetrieverResource, MessageFile
|
||||
|
||||
|
||||
class ConversationMessageTask:
|
||||
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
|
||||
inputs: dict, query: str, files: List[FileObj], streaming: bool,
|
||||
model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False,
|
||||
auto_generate_name: bool = True):
|
||||
self.start_at = time.perf_counter()
|
||||
|
||||
self.task_id = task_id
|
||||
|
||||
self.app = app
|
||||
self.tenant_id = app.tenant_id
|
||||
self.app_model_config = app_model_config
|
||||
self.is_override = is_override
|
||||
|
||||
self.user = user
|
||||
self.inputs = inputs
|
||||
self.query = query
|
||||
self.files = files
|
||||
self.streaming = streaming
|
||||
|
||||
self.conversation = conversation
|
||||
self.is_new_conversation = False
|
||||
|
||||
self.model_instance = model_instance
|
||||
|
||||
self.message = None
|
||||
|
||||
self.retriever_resource = None
|
||||
self.auto_generate_name = auto_generate_name
|
||||
|
||||
self.model_dict = self.app_model_config.model_dict
|
||||
self.provider_name = self.model_dict.get('provider')
|
||||
self.model_name = self.model_dict.get('name')
|
||||
self.mode = app.mode
|
||||
|
||||
self.init()
|
||||
|
||||
self._pub_handler = PubHandler(
|
||||
user=self.user,
|
||||
task_id=self.task_id,
|
||||
message=self.message,
|
||||
conversation=self.conversation,
|
||||
chain_pub=False, # disabled currently
|
||||
agent_thought_pub=True
|
||||
)
|
||||
|
||||
def init(self):
|
||||
|
||||
override_model_configs = None
|
||||
if self.is_override:
|
||||
override_model_configs = self.app_model_config.to_dict()
|
||||
|
||||
introduction = ''
|
||||
system_instruction = ''
|
||||
system_instruction_tokens = 0
|
||||
if self.mode == 'chat':
|
||||
introduction = self.app_model_config.opening_statement
|
||||
if introduction:
|
||||
prompt_template = PromptTemplateParser(template=introduction)
|
||||
prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
|
||||
try:
|
||||
introduction = prompt_template.format(prompt_inputs)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if self.app_model_config.pre_prompt:
|
||||
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
|
||||
system_instruction = system_message.content
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=self.tenant_id,
|
||||
model_provider_name=self.provider_name,
|
||||
model_name=self.model_name
|
||||
)
|
||||
system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))
|
||||
|
||||
if not self.conversation:
|
||||
self.is_new_conversation = True
|
||||
self.conversation = Conversation(
|
||||
app_id=self.app.id,
|
||||
app_model_config_id=self.app_model_config.id,
|
||||
model_provider=self.provider_name,
|
||||
model_id=self.model_name,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
mode=self.mode,
|
||||
name='New conversation',
|
||||
inputs=self.inputs,
|
||||
introduction=introduction,
|
||||
system_instruction=system_instruction,
|
||||
system_instruction_tokens=system_instruction_tokens,
|
||||
status='normal',
|
||||
from_source=('console' if isinstance(self.user, Account) else 'api'),
|
||||
from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
|
||||
from_account_id=(self.user.id if isinstance(self.user, Account) else None),
|
||||
)
|
||||
|
||||
db.session.add(self.conversation)
|
||||
db.session.commit()
|
||||
|
||||
self.message = Message(
|
||||
app_id=self.app.id,
|
||||
model_provider=self.provider_name,
|
||||
model_id=self.model_name,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
conversation_id=self.conversation.id,
|
||||
inputs=self.inputs,
|
||||
query=self.query,
|
||||
message="",
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
answer="",
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency=self.model_instance.get_currency(),
|
||||
from_source=('console' if isinstance(self.user, Account) else 'api'),
|
||||
from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
|
||||
from_account_id=(self.user.id if isinstance(self.user, Account) else None),
|
||||
agent_based=self.app_model_config.agent_mode_dict.get('enabled'),
|
||||
)
|
||||
|
||||
db.session.add(self.message)
|
||||
db.session.commit()
|
||||
|
||||
for file in self.files:
|
||||
message_file = MessageFile(
|
||||
message_id=self.message.id,
|
||||
type=file.type.value,
|
||||
transfer_method=file.transfer_method.value,
|
||||
url=file.url,
|
||||
upload_file_id=file.upload_file_id,
|
||||
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
|
||||
created_by=self.user.id
|
||||
)
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
|
||||
def append_message_text(self, text: str):
|
||||
if text is not None:
|
||||
self._pub_handler.pub_text(text)
|
||||
|
||||
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
|
||||
message_tokens = llm_message.prompt_tokens
|
||||
answer_tokens = llm_message.completion_tokens
|
||||
|
||||
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
|
||||
message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
|
||||
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
|
||||
|
||||
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
|
||||
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
|
||||
total_price = message_total_price + answer_total_price
|
||||
|
||||
self.message.message = llm_message.prompt
|
||||
self.message.message_tokens = message_tokens
|
||||
self.message.message_unit_price = message_unit_price
|
||||
self.message.message_price_unit = message_price_unit
|
||||
self.message.answer = PromptTemplateParser.remove_template_variables(
|
||||
llm_message.completion.strip()) if llm_message.completion else ''
|
||||
self.message.answer_tokens = answer_tokens
|
||||
self.message.answer_unit_price = answer_unit_price
|
||||
self.message.answer_price_unit = answer_price_unit
|
||||
self.message.provider_response_latency = time.perf_counter() - self.start_at
|
||||
self.message.total_price = total_price
|
||||
|
||||
db.session.commit()
|
||||
|
||||
message_was_created.send(
|
||||
self.message,
|
||||
conversation=self.conversation,
|
||||
is_first_message=self.is_new_conversation,
|
||||
auto_generate_name=self.auto_generate_name
|
||||
)
|
||||
|
||||
if not by_stopped:
|
||||
self.end()
|
||||
|
||||
def init_chain(self, chain_result: ChainResult):
|
||||
message_chain = MessageChain(
|
||||
message_id=self.message.id,
|
||||
type=chain_result.type,
|
||||
input=json.dumps(chain_result.prompt),
|
||||
output=''
|
||||
)
|
||||
|
||||
db.session.add(message_chain)
|
||||
db.session.commit()
|
||||
|
||||
return message_chain
|
||||
|
||||
def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
|
||||
message_chain.output = json.dumps(chain_result.completion)
|
||||
db.session.commit()
|
||||
|
||||
self._pub_handler.pub_chain(message_chain)
|
||||
|
||||
def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
|
||||
message_agent_thought = MessageAgentThought(
|
||||
message_id=self.message.id,
|
||||
message_chain_id=message_chain.id,
|
||||
position=agent_loop.position,
|
||||
thought=agent_loop.thought,
|
||||
tool=agent_loop.tool_name,
|
||||
tool_input=agent_loop.tool_input,
|
||||
message=agent_loop.prompt,
|
||||
message_price_unit=0,
|
||||
answer=agent_loop.completion,
|
||||
answer_price_unit=0,
|
||||
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
|
||||
created_by=self.user.id
|
||||
)
|
||||
|
||||
db.session.add(message_agent_thought)
|
||||
db.session.commit()
|
||||
|
||||
self._pub_handler.pub_agent_thought(message_agent_thought)
|
||||
|
||||
return message_agent_thought
|
||||
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
|
||||
agent_loop: AgentLoop):
|
||||
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
|
||||
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
|
||||
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
|
||||
|
||||
loop_message_tokens = agent_loop.prompt_tokens
|
||||
loop_answer_tokens = agent_loop.completion_tokens
|
||||
|
||||
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
|
||||
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
|
||||
loop_total_price = loop_message_total_price + loop_answer_total_price
|
||||
|
||||
message_agent_thought.observation = agent_loop.tool_output
|
||||
message_agent_thought.tool_process_data = '' # currently not support
|
||||
message_agent_thought.message_token = loop_message_tokens
|
||||
message_agent_thought.message_unit_price = agent_message_unit_price
|
||||
message_agent_thought.message_price_unit = agent_message_price_unit
|
||||
message_agent_thought.answer_token = loop_answer_tokens
|
||||
message_agent_thought.answer_unit_price = agent_answer_unit_price
|
||||
message_agent_thought.answer_price_unit = agent_answer_price_unit
|
||||
message_agent_thought.latency = agent_loop.latency
|
||||
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
|
||||
message_agent_thought.total_price = loop_total_price
|
||||
message_agent_thought.currency = agent_model_instance.get_currency()
|
||||
db.session.commit()
|
||||
|
||||
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_query_obj.dataset_id,
|
||||
content=dataset_query_obj.query,
|
||||
source='app',
|
||||
source_app_id=self.app.id,
|
||||
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
|
||||
created_by=self.user.id
|
||||
)
|
||||
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
|
||||
def on_dataset_query_finish(self, resource: List):
|
||||
if resource and len(resource) > 0:
|
||||
for item in resource:
|
||||
dataset_retriever_resource = DatasetRetrieverResource(
|
||||
message_id=self.message.id,
|
||||
position=item.get('position'),
|
||||
dataset_id=item.get('dataset_id'),
|
||||
dataset_name=item.get('dataset_name'),
|
||||
document_id=item.get('document_id'),
|
||||
document_name=item.get('document_name'),
|
||||
data_source_type=item.get('data_source_type'),
|
||||
segment_id=item.get('segment_id'),
|
||||
score=item.get('score') if 'score' in item else None,
|
||||
hit_count=item.get('hit_count') if 'hit_count' else None,
|
||||
word_count=item.get('word_count') if 'word_count' in item else None,
|
||||
segment_position=item.get('segment_position') if 'segment_position' in item else None,
|
||||
index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
|
||||
content=item.get('content'),
|
||||
retriever_from=item.get('retriever_from'),
|
||||
created_by=self.user.id
|
||||
)
|
||||
db.session.add(dataset_retriever_resource)
|
||||
db.session.commit()
|
||||
self.retriever_resource = resource
|
||||
|
||||
def on_message_replace(self, text: str):
|
||||
if text is not None:
|
||||
self._pub_handler.pub_message_replace(text)
|
||||
|
||||
def message_end(self):
|
||||
self._pub_handler.pub_message_end(self.retriever_resource)
|
||||
|
||||
def end(self):
|
||||
self._pub_handler.pub_message_end(self.retriever_resource)
|
||||
self._pub_handler.pub_end()
|
||||
|
||||
def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str):
|
||||
self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at)
|
||||
self._pub_handler.pub_end()
|
||||
|
||||
|
||||
class PubHandler:
|
||||
def __init__(self, user: Union[Account, EndUser], task_id: str,
|
||||
message: Message, conversation: Conversation,
|
||||
chain_pub: bool = False, agent_thought_pub: bool = False):
|
||||
self._channel = PubHandler.generate_channel_name(user, task_id)
|
||||
self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id)
|
||||
|
||||
self._task_id = task_id
|
||||
self._message = message
|
||||
self._conversation = conversation
|
||||
self._chain_pub = chain_pub
|
||||
self._agent_thought_pub = agent_thought_pub
|
||||
|
||||
@classmethod
|
||||
def generate_channel_name(cls, user: Union[Account, EndUser], task_id: str):
|
||||
if not user:
|
||||
raise ValueError("user is required")
|
||||
|
||||
user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
|
||||
return "generate_result:{}-{}".format(user_str, task_id)
|
||||
|
||||
@classmethod
|
||||
def generate_stopped_cache_key(cls, user: Union[Account, EndUser], task_id: str):
|
||||
user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
|
||||
return "generate_result_stopped:{}-{}".format(user_str, task_id)
|
||||
|
||||
def pub_text(self, text: str):
|
||||
content = {
|
||||
'event': 'message',
|
||||
'data': {
|
||||
'task_id': self._task_id,
|
||||
'message_id': str(self._message.id),
|
||||
'text': text,
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': str(self._conversation.id)
|
||||
}
|
||||
}
|
||||
|
||||
redis_client.publish(self._channel, json.dumps(content))
|
||||
|
||||
if self._is_stopped():
|
||||
self.pub_end()
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
def pub_message_replace(self, text: str):
|
||||
content = {
|
||||
'event': 'message_replace',
|
||||
'data': {
|
||||
'task_id': self._task_id,
|
||||
'message_id': str(self._message.id),
|
||||
'text': text,
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': str(self._conversation.id)
|
||||
}
|
||||
}
|
||||
|
||||
redis_client.publish(self._channel, json.dumps(content))
|
||||
|
||||
if self._is_stopped():
|
||||
self.pub_end()
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
def pub_chain(self, message_chain: MessageChain):
|
||||
if self._chain_pub:
|
||||
content = {
|
||||
'event': 'chain',
|
||||
'data': {
|
||||
'task_id': self._task_id,
|
||||
'message_id': self._message.id,
|
||||
'chain_id': message_chain.id,
|
||||
'type': message_chain.type,
|
||||
'input': json.loads(message_chain.input),
|
||||
'output': json.loads(message_chain.output),
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': self._conversation.id
|
||||
}
|
||||
}
|
||||
|
||||
redis_client.publish(self._channel, json.dumps(content))
|
||||
|
||||
if self._is_stopped():
|
||||
self.pub_end()
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
def pub_agent_thought(self, message_agent_thought: MessageAgentThought):
|
||||
if self._agent_thought_pub:
|
||||
content = {
|
||||
'event': 'agent_thought',
|
||||
'data': {
|
||||
'id': message_agent_thought.id,
|
||||
'task_id': self._task_id,
|
||||
'message_id': self._message.id,
|
||||
'chain_id': message_agent_thought.message_chain_id,
|
||||
'position': message_agent_thought.position,
|
||||
'thought': message_agent_thought.thought,
|
||||
'tool': message_agent_thought.tool,
|
||||
'tool_input': message_agent_thought.tool_input,
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': self._conversation.id
|
||||
}
|
||||
}
|
||||
|
||||
redis_client.publish(self._channel, json.dumps(content))
|
||||
|
||||
if self._is_stopped():
|
||||
self.pub_end()
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
def pub_message_end(self, retriever_resource: List):
|
||||
content = {
|
||||
'event': 'message_end',
|
||||
'data': {
|
||||
'task_id': self._task_id,
|
||||
'message_id': self._message.id,
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': self._conversation.id,
|
||||
}
|
||||
}
|
||||
if retriever_resource:
|
||||
content['data']['retriever_resources'] = retriever_resource
|
||||
redis_client.publish(self._channel, json.dumps(content))
|
||||
|
||||
if self._is_stopped():
|
||||
self.pub_end()
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float):
|
||||
content = {
|
||||
'event': 'annotation',
|
||||
'data': {
|
||||
'task_id': self._task_id,
|
||||
'message_id': self._message.id,
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': self._conversation.id,
|
||||
'text': text,
|
||||
'annotation_id': annotation_id,
|
||||
'annotation_author_name': annotation_author_name
|
||||
}
|
||||
}
|
||||
self._message.answer = text
|
||||
self._message.provider_response_latency = time.perf_counter() - start_at
|
||||
|
||||
db.session.commit()
|
||||
|
||||
redis_client.publish(self._channel, json.dumps(content))
|
||||
|
||||
if self._is_stopped():
|
||||
self.pub_end()
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
def pub_end(self):
|
||||
content = {
|
||||
'event': 'end',
|
||||
}
|
||||
|
||||
redis_client.publish(self._channel, json.dumps(content))
|
||||
|
||||
@classmethod
|
||||
def pub_error(cls, user: Union[Account, EndUser], task_id: str, e):
|
||||
content = {
|
||||
'error': type(e).__name__,
|
||||
'description': e.description if getattr(e, 'description', None) is not None else str(e)
|
||||
}
|
||||
|
||||
channel = cls.generate_channel_name(user, task_id)
|
||||
redis_client.publish(channel, json.dumps(content))
|
||||
|
||||
def _is_stopped(self):
|
||||
return redis_client.get(self._stopped_cache_key) is not None
|
||||
|
||||
@classmethod
|
||||
def ping(cls, user: Union[Account, EndUser], task_id: str):
|
||||
content = {
|
||||
'event': 'ping'
|
||||
}
|
||||
|
||||
channel = cls.generate_channel_name(user, task_id)
|
||||
redis_client.publish(channel, json.dumps(content))
|
||||
|
||||
@classmethod
|
||||
def stop(cls, user: Union[Account, EndUser], task_id: str):
|
||||
stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
|
||||
class ConversationTaskStoppedException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ConversationTaskInterruptException(Exception):
|
||||
pass
|
|
@ -1,9 +1,11 @@
|
|||
from typing import Any, Dict, Optional, Sequence
|
||||
from typing import Any, Dict, Optional, Sequence, cast
|
||||
|
||||
from langchain.schema import Document
|
||||
from sqlalchemy import func
|
||||
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
|
||||
|
@ -69,10 +71,12 @@ class DatasetDocumentStore:
|
|||
max_position = 0
|
||||
embedding_model = None
|
||||
if self._dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
model_provider_name=self._dataset.embedding_model_provider,
|
||||
model_name=self._dataset.embedding_model
|
||||
provider=self._dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=self._dataset.embedding_model
|
||||
)
|
||||
|
||||
for doc in docs:
|
||||
|
@ -89,7 +93,16 @@ class DatasetDocumentStore:
|
|||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0
|
||||
if embedding_model:
|
||||
model_type_instance = embedding_model.model_type_instance
|
||||
model_type_instance = cast(TextEmbeddingModel, model_type_instance)
|
||||
tokens = model_type_instance.get_num_tokens(
|
||||
model=embedding_model.model,
|
||||
credentials=embedding_model.credentials,
|
||||
texts=[doc.page_content]
|
||||
)
|
||||
else:
|
||||
tokens = 0
|
||||
|
||||
if not segment_document:
|
||||
max_position += 1
|
||||
|
|
|
@ -1,19 +1,22 @@
|
|||
import logging
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from core.model_manager import ModelInstance
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.dataset import Embedding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CacheEmbedding(Embeddings):
|
||||
def __init__(self, embeddings: BaseEmbedding):
|
||||
self._embeddings = embeddings
|
||||
def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
|
||||
self._model_instance = model_instance
|
||||
self._user = user
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed search docs."""
|
||||
|
@ -22,7 +25,7 @@ class CacheEmbedding(Embeddings):
|
|||
embedding_queue_indices = []
|
||||
for i, text in enumerate(texts):
|
||||
hash = helper.generate_text_hash(text)
|
||||
embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
|
||||
embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
|
||||
if embedding:
|
||||
text_embeddings[i] = embedding.get_embedding()
|
||||
else:
|
||||
|
@ -30,15 +33,21 @@ class CacheEmbedding(Embeddings):
|
|||
|
||||
if embedding_queue_indices:
|
||||
try:
|
||||
embedding_results = self._embeddings.client.embed_documents([texts[i] for i in embedding_queue_indices])
|
||||
embedding_result = self._model_instance.invoke_text_embedding(
|
||||
texts=[texts[i] for i in embedding_queue_indices],
|
||||
user=self._user
|
||||
)
|
||||
|
||||
embedding_results = embedding_result.embeddings
|
||||
except Exception as ex:
|
||||
raise self._embeddings.handle_exceptions(ex)
|
||||
logger.error('Failed to embed documents: ', ex)
|
||||
raise ex
|
||||
|
||||
for i, indice in enumerate(embedding_queue_indices):
|
||||
hash = helper.generate_text_hash(texts[indice])
|
||||
|
||||
try:
|
||||
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
|
||||
embedding = Embedding(model_name=self._model_instance.model, hash=hash)
|
||||
vector = embedding_results[i]
|
||||
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
|
||||
text_embeddings[indice] = normalized_embedding
|
||||
|
@ -58,18 +67,23 @@ class CacheEmbedding(Embeddings):
|
|||
"""Embed query text."""
|
||||
# use doc embedding cache or store if not exists
|
||||
hash = helper.generate_text_hash(text)
|
||||
embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
|
||||
embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
|
||||
if embedding:
|
||||
return embedding.get_embedding()
|
||||
|
||||
try:
|
||||
embedding_results = self._embeddings.client.embed_query(text)
|
||||
embedding_result = self._model_instance.invoke_text_embedding(
|
||||
texts=[text],
|
||||
user=self._user
|
||||
)
|
||||
|
||||
embedding_results = embedding_result.embeddings[0]
|
||||
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
|
||||
except Exception as ex:
|
||||
raise self._embeddings.handle_exceptions(ex)
|
||||
raise ex
|
||||
|
||||
try:
|
||||
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
|
||||
embedding = Embedding(model_name=self._model_instance.model, hash=hash)
|
||||
embedding.set_embedding(embedding_results)
|
||||
db.session.add(embedding)
|
||||
db.session.commit()
|
||||
|
@ -79,4 +93,3 @@ class CacheEmbedding(Embeddings):
|
|||
logging.exception('Failed to add embedding to db')
|
||||
|
||||
return embedding_results
|
||||
|
||||
|
|
265
api/core/entities/application_entities.py
Normal file
265
api/core/entities/application_entities.py
Normal file
|
@ -0,0 +1,265 @@
|
|||
from enum import Enum
|
||||
from typing import Optional, Any, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.file.file_obj import FileObj
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
|
||||
|
||||
class ModelConfigEntity(BaseModel):
|
||||
"""
|
||||
Model Config Entity.
|
||||
"""
|
||||
provider: str
|
||||
model: str
|
||||
model_schema: AIModelEntity
|
||||
mode: str
|
||||
provider_model_bundle: ProviderModelBundle
|
||||
credentials: dict[str, Any] = {}
|
||||
parameters: dict[str, Any] = {}
|
||||
stop: list[str] = []
|
||||
|
||||
|
||||
class AdvancedChatMessageEntity(BaseModel):
|
||||
"""
|
||||
Advanced Chat Message Entity.
|
||||
"""
|
||||
text: str
|
||||
role: PromptMessageRole
|
||||
|
||||
|
||||
class AdvancedChatPromptTemplateEntity(BaseModel):
|
||||
"""
|
||||
Advanced Chat Prompt Template Entity.
|
||||
"""
|
||||
messages: list[AdvancedChatMessageEntity]
|
||||
|
||||
|
||||
class AdvancedCompletionPromptTemplateEntity(BaseModel):
|
||||
"""
|
||||
Advanced Completion Prompt Template Entity.
|
||||
"""
|
||||
class RolePrefixEntity(BaseModel):
|
||||
"""
|
||||
Role Prefix Entity.
|
||||
"""
|
||||
user: str
|
||||
assistant: str
|
||||
|
||||
prompt: str
|
||||
role_prefix: Optional[RolePrefixEntity] = None
|
||||
|
||||
|
||||
class PromptTemplateEntity(BaseModel):
|
||||
"""
|
||||
Prompt Template Entity.
|
||||
"""
|
||||
class PromptType(Enum):
|
||||
"""
|
||||
Prompt Type.
|
||||
'simple', 'advanced'
|
||||
"""
|
||||
SIMPLE = 'simple'
|
||||
ADVANCED = 'advanced'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'PromptType':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid prompt type value {value}')
|
||||
|
||||
prompt_type: PromptType
|
||||
simple_prompt_template: Optional[str] = None
|
||||
advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
|
||||
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
|
||||
|
||||
|
||||
class ExternalDataVariableEntity(BaseModel):
|
||||
"""
|
||||
External Data Variable Entity.
|
||||
"""
|
||||
variable: str
|
||||
type: str
|
||||
config: dict[str, Any] = {}
|
||||
|
||||
|
||||
class DatasetRetrieveConfigEntity(BaseModel):
|
||||
"""
|
||||
Dataset Retrieve Config Entity.
|
||||
"""
|
||||
class RetrieveStrategy(Enum):
|
||||
"""
|
||||
Dataset Retrieve Strategy.
|
||||
'single' or 'multiple'
|
||||
"""
|
||||
SINGLE = 'single'
|
||||
MULTIPLE = 'multiple'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'RetrieveStrategy':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid retrieve strategy value {value}')
|
||||
|
||||
query_variable: Optional[str] = None # Only when app mode is completion
|
||||
|
||||
retrieve_strategy: RetrieveStrategy
|
||||
single_strategy: Optional[str] = None # for temp
|
||||
top_k: Optional[int] = None
|
||||
score_threshold: Optional[float] = None
|
||||
reranking_model: Optional[dict] = None
|
||||
|
||||
|
||||
class DatasetEntity(BaseModel):
|
||||
"""
|
||||
Dataset Config Entity.
|
||||
"""
|
||||
dataset_ids: list[str]
|
||||
retrieve_config: DatasetRetrieveConfigEntity
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceEntity(BaseModel):
|
||||
"""
|
||||
Sensitive Word Avoidance Entity.
|
||||
"""
|
||||
type: str
|
||||
config: dict[str, Any] = {}
|
||||
|
||||
|
||||
class FileUploadEntity(BaseModel):
|
||||
"""
|
||||
File Upload Entity.
|
||||
"""
|
||||
image_config: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class AgentToolEntity(BaseModel):
|
||||
"""
|
||||
Agent Tool Entity.
|
||||
"""
|
||||
tool_id: str
|
||||
config: dict[str, Any] = {}
|
||||
|
||||
|
||||
class AgentEntity(BaseModel):
|
||||
"""
|
||||
Agent Entity.
|
||||
"""
|
||||
class Strategy(Enum):
|
||||
"""
|
||||
Agent Strategy.
|
||||
"""
|
||||
CHAIN_OF_THOUGHT = 'chain-of-thought'
|
||||
FUNCTION_CALLING = 'function-calling'
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
strategy: Strategy
|
||||
tools: list[AgentToolEntity] = []
|
||||
|
||||
|
||||
class AppOrchestrationConfigEntity(BaseModel):
|
||||
"""
|
||||
App Orchestration Config Entity.
|
||||
"""
|
||||
model_config: ModelConfigEntity
|
||||
prompt_template: PromptTemplateEntity
|
||||
external_data_variables: list[ExternalDataVariableEntity] = []
|
||||
agent: Optional[AgentEntity] = None
|
||||
|
||||
# features
|
||||
dataset: Optional[DatasetEntity] = None
|
||||
file_upload: Optional[FileUploadEntity] = None
|
||||
opening_statement: Optional[str] = None
|
||||
suggested_questions_after_answer: bool = False
|
||||
show_retrieve_source: bool = False
|
||||
more_like_this: bool = False
|
||||
speech_to_text: bool = False
|
||||
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
|
||||
|
||||
|
||||
class InvokeFrom(Enum):
|
||||
"""
|
||||
Invoke From.
|
||||
"""
|
||||
SERVICE_API = 'service-api'
|
||||
WEB_APP = 'web-app'
|
||||
EXPLORE = 'explore'
|
||||
DEBUGGER = 'debugger'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'InvokeFrom':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid invoke from value {value}')
|
||||
|
||||
def to_source(self) -> str:
|
||||
"""
|
||||
Get source of invoke from.
|
||||
|
||||
:return: source
|
||||
"""
|
||||
if self == InvokeFrom.WEB_APP:
|
||||
return 'web_app'
|
||||
elif self == InvokeFrom.DEBUGGER:
|
||||
return 'dev'
|
||||
elif self == InvokeFrom.EXPLORE:
|
||||
return 'explore_app'
|
||||
elif self == InvokeFrom.SERVICE_API:
|
||||
return 'api'
|
||||
|
||||
return 'dev'
|
||||
|
||||
|
||||
class ApplicationGenerateEntity(BaseModel):
|
||||
"""
|
||||
Application Generate Entity.
|
||||
"""
|
||||
task_id: str
|
||||
tenant_id: str
|
||||
|
||||
app_id: str
|
||||
app_model_config_id: str
|
||||
# for save
|
||||
app_model_config_dict: dict
|
||||
app_model_config_override: bool
|
||||
|
||||
# Converted from app_model_config to Entity object, or directly covered by external input
|
||||
app_orchestration_config_entity: AppOrchestrationConfigEntity
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
inputs: dict[str, str]
|
||||
query: Optional[str] = None
|
||||
files: list[FileObj] = []
|
||||
user_id: str
|
||||
|
||||
# extras
|
||||
stream: bool
|
||||
invoke_from: InvokeFrom
|
||||
|
||||
# extra parameters, like: auto_generate_conversation_name
|
||||
extras: dict[str, Any] = {}
|
128
api/core/entities/message_entities.py
Normal file
128
api/core/entities/message_entities.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
import enum
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage, TextPromptMessageContent, \
|
||||
ImagePromptMessageContent, AssistantPromptMessage, SystemPromptMessage, ToolPromptMessage
|
||||
|
||||
|
||||
class PromptMessageFileType(enum.Enum):
|
||||
IMAGE = 'image'
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in PromptMessageFileType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class PromptMessageFile(BaseModel):
|
||||
type: PromptMessageFileType
|
||||
data: Any
|
||||
|
||||
|
||||
class ImagePromptMessageFile(PromptMessageFile):
|
||||
class DETAIL(enum.Enum):
|
||||
LOW = 'low'
|
||||
HIGH = 'high'
|
||||
|
||||
type: PromptMessageFileType = PromptMessageFileType.IMAGE
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
|
||||
|
||||
class LCHumanMessageWithFiles(HumanMessage):
|
||||
# content: Union[str, List[Union[str, Dict]]]
|
||||
content: str
|
||||
files: list[PromptMessageFile]
|
||||
|
||||
|
||||
def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
|
||||
prompt_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, HumanMessage):
|
||||
if isinstance(message, LCHumanMessageWithFiles):
|
||||
file_prompt_message_contents = []
|
||||
for file in message.files:
|
||||
if file.type == PromptMessageFileType.IMAGE:
|
||||
file = cast(ImagePromptMessageFile, file)
|
||||
file_prompt_message_contents.append(ImagePromptMessageContent(
|
||||
data=file.data,
|
||||
detail=ImagePromptMessageContent.DETAIL.HIGH
|
||||
if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW
|
||||
))
|
||||
|
||||
prompt_message_contents = [TextPromptMessageContent(data=message.content)]
|
||||
prompt_message_contents.extend(file_prompt_message_contents)
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=message.content))
|
||||
elif isinstance(message, AIMessage):
|
||||
message_kwargs = {
|
||||
'content': message.content
|
||||
}
|
||||
|
||||
if 'function_call' in message.additional_kwargs:
|
||||
message_kwargs['tool_calls'] = [
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=message.additional_kwargs['function_call']['id'],
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=message.additional_kwargs['function_call']['name'],
|
||||
arguments=message.additional_kwargs['function_call']['arguments']
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
prompt_messages.append(AssistantPromptMessage(**message_kwargs))
|
||||
elif isinstance(message, SystemMessage):
|
||||
prompt_messages.append(SystemPromptMessage(content=message.content))
|
||||
elif isinstance(message, FunctionMessage):
|
||||
prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
|
||||
def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]:
|
||||
messages = []
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, str):
|
||||
messages.append(HumanMessage(content=prompt_message.content))
|
||||
else:
|
||||
message_contents = []
|
||||
for content in prompt_message.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
message_contents.append(content.data)
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
message_contents.append({
|
||||
'type': 'image',
|
||||
'data': content.data,
|
||||
'detail': content.detail.value
|
||||
})
|
||||
|
||||
messages.append(HumanMessage(content=message_contents))
|
||||
elif isinstance(prompt_message, AssistantPromptMessage):
|
||||
message_kwargs = {
|
||||
'content': prompt_message.content
|
||||
}
|
||||
|
||||
if prompt_message.tool_calls:
|
||||
message_kwargs['additional_kwargs'] = {
|
||||
'function_call': {
|
||||
'id': prompt_message.tool_calls[0].id,
|
||||
'name': prompt_message.tool_calls[0].function.name,
|
||||
'arguments': prompt_message.tool_calls[0].function.arguments
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(AIMessage(**message_kwargs))
|
||||
elif isinstance(prompt_message, SystemPromptMessage):
|
||||
messages.append(SystemMessage(content=prompt_message.content))
|
||||
elif isinstance(prompt_message, ToolPromptMessage):
|
||||
messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content))
|
||||
|
||||
return messages
|
71
api/core/entities/model_entities.py
Normal file
71
api/core/entities/model_entities.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ProviderModel, ModelType
|
||||
from core.model_runtime.entities.provider_entities import SimpleProviderEntity, ProviderEntity
|
||||
|
||||
|
||||
class ModelStatus(Enum):
|
||||
"""
|
||||
Enum class for model status.
|
||||
"""
|
||||
ACTIVE = "active"
|
||||
NO_CONFIGURE = "no-configure"
|
||||
QUOTA_EXCEEDED = "quota-exceeded"
|
||||
NO_PERMISSION = "no-permission"
|
||||
|
||||
|
||||
class SimpleModelProviderEntity(BaseModel):
|
||||
"""
|
||||
Simple provider.
|
||||
"""
|
||||
provider: str
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
supported_model_types: list[ModelType]
|
||||
|
||||
def __init__(self, provider_entity: ProviderEntity) -> None:
|
||||
"""
|
||||
Init simple provider.
|
||||
|
||||
:param provider_entity: provider entity
|
||||
"""
|
||||
super().__init__(
|
||||
provider=provider_entity.provider,
|
||||
label=provider_entity.label,
|
||||
icon_small=provider_entity.icon_small,
|
||||
icon_large=provider_entity.icon_large,
|
||||
supported_model_types=provider_entity.supported_model_types
|
||||
)
|
||||
|
||||
|
||||
class ModelWithProviderEntity(ProviderModel):
|
||||
"""
|
||||
Model with provider entity.
|
||||
"""
|
||||
provider: SimpleModelProviderEntity
|
||||
status: ModelStatus
|
||||
|
||||
|
||||
class DefaultModelProviderEntity(BaseModel):
|
||||
"""
|
||||
Default model provider entity.
|
||||
"""
|
||||
provider: str
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
supported_model_types: list[ModelType]
|
||||
|
||||
|
||||
class DefaultModelEntity(BaseModel):
|
||||
"""
|
||||
Default model entity.
|
||||
"""
|
||||
model: str
|
||||
model_type: ModelType
|
||||
provider: DefaultModelProviderEntity
|
657
api/core/entities/provider_configuration.py
Normal file
657
api/core/entities/provider_configuration.py
Normal file
|
@ -0,0 +1,657 @@
|
|||
import datetime
|
||||
import json
|
||||
import time
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional, List, Dict, Tuple, Iterator
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
|
||||
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
|
||||
from core.helper import encrypter
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
from core.model_runtime.utils import encoders
|
||||
from extensions.ext_database import db
|
||||
from models.provider import ProviderType, Provider, ProviderModel, TenantPreferredModelProvider
|
||||
|
||||
|
||||
class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider configuration.
|
||||
"""
|
||||
tenant_id: str
|
||||
provider: ProviderEntity
|
||||
preferred_provider_type: ProviderType
|
||||
using_provider_type: ProviderType
|
||||
system_configuration: SystemConfiguration
|
||||
custom_configuration: CustomConfiguration
|
||||
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
|
||||
"""
|
||||
Get current credentials.
|
||||
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
if self.using_provider_type == ProviderType.SYSTEM:
|
||||
return self.system_configuration.credentials
|
||||
else:
|
||||
if self.custom_configuration.models:
|
||||
for model_configuration in self.custom_configuration.models:
|
||||
if model_configuration.model_type == model_type and model_configuration.model == model:
|
||||
return model_configuration.credentials
|
||||
|
||||
if self.custom_configuration.provider:
|
||||
return self.custom_configuration.provider.credentials
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_system_configuration_status(self) -> SystemConfigurationStatus:
|
||||
"""
|
||||
Get system configuration status.
|
||||
:return:
|
||||
"""
|
||||
if self.system_configuration.enabled is False:
|
||||
return SystemConfigurationStatus.UNSUPPORTED
|
||||
|
||||
current_quota_type = self.system_configuration.current_quota_type
|
||||
current_quota_configuration = next(
|
||||
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
|
||||
None
|
||||
)
|
||||
|
||||
return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
|
||||
SystemConfigurationStatus.QUOTA_EXCEEDED
|
||||
|
||||
def is_custom_configuration_available(self) -> bool:
|
||||
"""
|
||||
Check custom configuration available.
|
||||
:return:
|
||||
"""
|
||||
return (self.custom_configuration.provider is not None
|
||||
or len(self.custom_configuration.models) > 0)
|
||||
|
||||
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||
"""
|
||||
Get custom credentials.
|
||||
|
||||
:param obfuscated: obfuscated secret data in credentials
|
||||
:return:
|
||||
"""
|
||||
if self.custom_configuration.provider is None:
|
||||
return None
|
||||
|
||||
credentials = self.custom_configuration.provider.credentials
|
||||
if not obfuscated:
|
||||
return credentials
|
||||
|
||||
# Obfuscate credentials
|
||||
return self._obfuscated_credentials(
|
||||
credentials=credentials,
|
||||
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema else []
|
||||
)
|
||||
|
||||
def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
:return:
|
||||
"""
|
||||
# get provider
|
||||
provider_record = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).first()
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema else []
|
||||
)
|
||||
|
||||
if provider_record:
|
||||
try:
|
||||
original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
# encrypt credentials
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == '[__HIDDEN__]' and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
||||
|
||||
model_provider_factory.provider_credentials_validate(
|
||||
self.provider.provider,
|
||||
credentials
|
||||
)
|
||||
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return provider_record, credentials
|
||||
|
||||
def add_or_update_custom_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Add or update custom provider credentials.
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
# validate custom provider config
|
||||
provider_record, credentials = self.custom_credentials_validate(credentials)
|
||||
|
||||
# save provider
|
||||
# Note: Do not switch the preferred provider, which allows users to use quotas first
|
||||
if provider_record:
|
||||
provider_record.encrypted_config = json.dumps(credentials)
|
||||
provider_record.is_valid = True
|
||||
provider_record.updated_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
else:
|
||||
provider_record = Provider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
is_valid=True
|
||||
)
|
||||
db.session.add(provider_record)
|
||||
db.session.commit()
|
||||
|
||||
self.switch_preferred_provider_type(ProviderType.CUSTOM)
|
||||
|
||||
def delete_custom_credentials(self) -> None:
|
||||
"""
|
||||
Delete custom provider credentials.
|
||||
:return:
|
||||
"""
|
||||
# get provider
|
||||
provider_record = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).first()
|
||||
|
||||
# delete provider
|
||||
if provider_record:
|
||||
self.switch_preferred_provider_type(ProviderType.SYSTEM)
|
||||
|
||||
db.session.delete(provider_record)
|
||||
db.session.commit()
|
||||
|
||||
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
|
||||
-> Optional[dict]:
|
||||
"""
|
||||
Get custom model credentials.
|
||||
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param obfuscated: obfuscated secret data in credentials
|
||||
:return:
|
||||
"""
|
||||
if not self.custom_configuration.models:
|
||||
return None
|
||||
|
||||
for model_configuration in self.custom_configuration.models:
|
||||
if model_configuration.model_type == model_type and model_configuration.model == model:
|
||||
credentials = model_configuration.credentials
|
||||
if not obfuscated:
|
||||
return credentials
|
||||
|
||||
# Obfuscate credentials
|
||||
return self._obfuscated_credentials(
|
||||
credentials=credentials,
|
||||
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema else []
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
|
||||
-> Tuple[ProviderModel, dict]:
|
||||
"""
|
||||
Validate custom model credentials.
|
||||
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
# get provider model
|
||||
provider_model_record = db.session.query(ProviderModel) \
|
||||
.filter(
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name == self.provider.provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type()
|
||||
).first()
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema else []
|
||||
)
|
||||
|
||||
if provider_model_record:
|
||||
try:
|
||||
original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
# decrypt credentials
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == '[__HIDDEN__]' and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
||||
|
||||
model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider,
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials
|
||||
)
|
||||
|
||||
model_schema = (
|
||||
model_provider_factory.get_provider_instance(self.provider.provider)
|
||||
.get_model_instance(model_type)._get_customizable_model_schema(
|
||||
model=model,
|
||||
credentials=credentials
|
||||
)
|
||||
)
|
||||
|
||||
if model_schema:
|
||||
credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema))
|
||||
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return provider_model_record, credentials
|
||||
|
||||
def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Add or update custom model credentials.
|
||||
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
# validate custom model config
|
||||
provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
|
||||
|
||||
# save provider model
|
||||
# Note: Do not switch the preferred provider, which allows users to use quotas first
|
||||
if provider_model_record:
|
||||
provider_model_record.encrypted_config = json.dumps(credentials)
|
||||
provider_model_record.is_valid = True
|
||||
provider_model_record.updated_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
else:
|
||||
provider_model_record = ProviderModel(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
encrypted_config=json.dumps(credentials),
|
||||
is_valid=True
|
||||
)
|
||||
db.session.add(provider_model_record)
|
||||
db.session.commit()
|
||||
|
||||
def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
|
||||
"""
|
||||
Delete custom model credentials.
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
# get provider model
|
||||
provider_model_record = db.session.query(ProviderModel) \
|
||||
.filter(
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name == self.provider.provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type()
|
||||
).first()
|
||||
|
||||
# delete provider model
|
||||
if provider_model_record:
|
||||
db.session.delete(provider_model_record)
|
||||
db.session.commit()
|
||||
|
||||
def get_provider_instance(self) -> ModelProvider:
|
||||
"""
|
||||
Get provider instance.
|
||||
:return:
|
||||
"""
|
||||
return model_provider_factory.get_provider_instance(self.provider.provider)
|
||||
|
||||
def get_model_type_instance(self, model_type: ModelType) -> AIModel:
|
||||
"""
|
||||
Get current model type instance.
|
||||
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
# Get provider instance
|
||||
provider_instance = self.get_provider_instance()
|
||||
|
||||
# Get model instance of LLM
|
||||
return provider_instance.get_model_instance(model_type)
|
||||
|
||||
def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
|
||||
"""
|
||||
Switch preferred provider type.
|
||||
:param provider_type:
|
||||
:return:
|
||||
"""
|
||||
if provider_type == self.preferred_provider_type:
|
||||
return
|
||||
|
||||
if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
|
||||
return
|
||||
|
||||
# get preferred provider
|
||||
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
|
||||
.filter(
|
||||
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
||||
TenantPreferredModelProvider.provider_name == self.provider.provider
|
||||
).first()
|
||||
|
||||
if preferred_model_provider:
|
||||
preferred_model_provider.preferred_provider_type = provider_type.value
|
||||
else:
|
||||
preferred_model_provider = TenantPreferredModelProvider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
preferred_provider_type=provider_type.value
|
||||
)
|
||||
db.session.add(preferred_model_provider)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
|
||||
"""
|
||||
Extract secret input form variables.
|
||||
|
||||
:param credential_form_schemas:
|
||||
:return:
|
||||
"""
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
if credential_form_schema.type == FormType.SECRET_INPUT:
|
||||
secret_input_form_variables.append(credential_form_schema.variable)
|
||||
|
||||
return secret_input_form_variables
|
||||
|
||||
def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
|
||||
"""
|
||||
Obfuscated credentials.
|
||||
|
||||
:param credentials: credentials
|
||||
:param credential_form_schemas: credential form schemas
|
||||
:return:
|
||||
"""
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self._extract_secret_variables(
|
||||
credential_form_schemas
|
||||
)
|
||||
|
||||
# Obfuscate provider credentials
|
||||
copy_credentials = credentials.copy()
|
||||
for key, value in copy_credentials.items():
|
||||
if key in credential_secret_variables:
|
||||
copy_credentials[key] = encrypter.obfuscated_token(value)
|
||||
|
||||
return copy_credentials
|
||||
|
||||
def get_provider_model(self, model_type: ModelType,
|
||||
model: str,
|
||||
only_active: bool = False) -> Optional[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get provider model.
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param only_active: return active model only
|
||||
:return:
|
||||
"""
|
||||
provider_models = self.get_provider_models(model_type, only_active)
|
||||
|
||||
for provider_model in provider_models:
|
||||
if provider_model.model == model:
|
||||
return provider_model
|
||||
|
||||
return None
|
||||
|
||||
def get_provider_models(self, model_type: Optional[ModelType] = None,
|
||||
only_active: bool = False) -> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get provider models.
|
||||
:param model_type: model type
|
||||
:param only_active: only active models
|
||||
:return:
|
||||
"""
|
||||
provider_instance = self.get_provider_instance()
|
||||
|
||||
model_types = []
|
||||
if model_type:
|
||||
model_types.append(model_type)
|
||||
else:
|
||||
model_types = provider_instance.get_provider_schema().supported_model_types
|
||||
|
||||
if self.using_provider_type == ProviderType.SYSTEM:
|
||||
provider_models = self._get_system_provider_models(
|
||||
model_types=model_types,
|
||||
provider_instance=provider_instance
|
||||
)
|
||||
else:
|
||||
provider_models = self._get_custom_provider_models(
|
||||
model_types=model_types,
|
||||
provider_instance=provider_instance
|
||||
)
|
||||
|
||||
if only_active:
|
||||
provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
|
||||
|
||||
# resort provider_models
|
||||
return sorted(provider_models, key=lambda x: x.model_type.value)
|
||||
|
||||
def _get_system_provider_models(self,
|
||||
model_types: list[ModelType],
|
||||
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get system provider models.
|
||||
|
||||
:param model_types: model types
|
||||
:param provider_instance: provider instance
|
||||
:return:
|
||||
"""
|
||||
provider_models = []
|
||||
for model_type in model_types:
|
||||
provider_models.extend(
|
||||
[
|
||||
ModelWithProviderEntity(
|
||||
**m.dict(),
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=ModelStatus.ACTIVE
|
||||
)
|
||||
for m in provider_instance.models(model_type)
|
||||
]
|
||||
)
|
||||
|
||||
for quota_configuration in self.system_configuration.quota_configurations:
|
||||
if self.system_configuration.current_quota_type != quota_configuration.quota_type:
|
||||
continue
|
||||
|
||||
restrict_llms = quota_configuration.restrict_llms
|
||||
if not restrict_llms:
|
||||
break
|
||||
|
||||
# if llm name not in restricted llm list, remove it
|
||||
for m in provider_models:
|
||||
if m.model_type == ModelType.LLM and m.model not in restrict_llms:
|
||||
m.status = ModelStatus.NO_PERMISSION
|
||||
elif not quota_configuration.is_valid:
|
||||
m.status = ModelStatus.QUOTA_EXCEEDED
|
||||
|
||||
return provider_models
|
||||
|
||||
def _get_custom_provider_models(self,
|
||||
model_types: list[ModelType],
|
||||
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get custom provider models.
|
||||
|
||||
:param model_types: model types
|
||||
:param provider_instance: provider instance
|
||||
:return:
|
||||
"""
|
||||
provider_models = []
|
||||
|
||||
credentials = None
|
||||
if self.custom_configuration.provider:
|
||||
credentials = self.custom_configuration.provider.credentials
|
||||
|
||||
for model_type in model_types:
|
||||
if model_type not in self.provider.supported_model_types:
|
||||
continue
|
||||
|
||||
models = provider_instance.models(model_type)
|
||||
for m in models:
|
||||
provider_models.append(
|
||||
ModelWithProviderEntity(
|
||||
**m.dict(),
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
|
||||
)
|
||||
)
|
||||
|
||||
# custom models
|
||||
for model_configuration in self.custom_configuration.models:
|
||||
if model_configuration.model_type not in model_types:
|
||||
continue
|
||||
|
||||
custom_model_schema = (
|
||||
provider_instance.get_model_instance(model_configuration.model_type)
|
||||
.get_customizable_model_schema_from_credentials(
|
||||
model_configuration.model,
|
||||
model_configuration.credentials
|
||||
)
|
||||
)
|
||||
|
||||
if not custom_model_schema:
|
||||
continue
|
||||
|
||||
provider_models.append(
|
||||
ModelWithProviderEntity(
|
||||
**custom_model_schema.dict(),
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=ModelStatus.ACTIVE
|
||||
)
|
||||
)
|
||||
|
||||
return provider_models
|
||||
|
||||
|
||||
class ProviderConfigurations(BaseModel):
|
||||
"""
|
||||
Model class for provider configuration dict.
|
||||
"""
|
||||
tenant_id: str
|
||||
configurations: Dict[str, ProviderConfiguration] = {}
|
||||
|
||||
def __init__(self, tenant_id: str):
|
||||
super().__init__(tenant_id=tenant_id)
|
||||
|
||||
def get_models(self,
|
||||
provider: Optional[str] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
only_active: bool = False) \
|
||||
-> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get available models.
|
||||
|
||||
If preferred provider type is `system`:
|
||||
Get the current **system mode** if provider supported,
|
||||
if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
|
||||
If there is no model configured in custom mode, it is treated as no_configure.
|
||||
system > custom > no_configure
|
||||
|
||||
If preferred provider type is `custom`:
|
||||
If custom credentials are configured, it is treated as custom mode.
|
||||
Otherwise, get the current **system mode** if supported,
|
||||
If all system modes are not available (no quota), it is treated as no_configure.
|
||||
custom > system > no_configure
|
||||
|
||||
If real mode is `system`, use system credentials to get models,
|
||||
paid quotas > provider free quotas > system free quotas
|
||||
include pre-defined models (exclude GPT-4, status marked as `no_permission`).
|
||||
If real mode is `custom`, use workspace custom credentials to get models,
|
||||
include pre-defined models, custom models(manual append).
|
||||
If real mode is `no_configure`, only return pre-defined models from `model runtime`.
|
||||
(model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
|
||||
model status marked as `active` is available.
|
||||
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param only_active: only active models
|
||||
:return:
|
||||
"""
|
||||
all_models = []
|
||||
for provider_configuration in self.values():
|
||||
if provider and provider_configuration.provider.provider != provider:
|
||||
continue
|
||||
|
||||
all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
|
||||
|
||||
return all_models
|
||||
|
||||
def to_list(self) -> List[ProviderConfiguration]:
|
||||
"""
|
||||
Convert to list.
|
||||
|
||||
:return:
|
||||
"""
|
||||
return list(self.values())
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.configurations[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.configurations[key] = value
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.configurations)
|
||||
|
||||
def values(self) -> Iterator[ProviderConfiguration]:
|
||||
return self.configurations.values()
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self.configurations.get(key, default)
|
||||
|
||||
|
||||
class ProviderModelBundle(BaseModel):
|
||||
"""
|
||||
Provider model bundle.
|
||||
"""
|
||||
configuration: ProviderConfiguration
|
||||
provider_instance: ModelProvider
|
||||
model_type_instance: AIModel
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
67
api/core/entities/provider_entities.py
Normal file
67
api/core/entities/provider_entities.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import ProviderQuotaType
|
||||
|
||||
|
||||
class QuotaUnit(Enum):
|
||||
TIMES = 'times'
|
||||
TOKENS = 'tokens'
|
||||
|
||||
|
||||
class SystemConfigurationStatus(Enum):
|
||||
"""
|
||||
Enum class for system configuration status.
|
||||
"""
|
||||
ACTIVE = 'active'
|
||||
QUOTA_EXCEEDED = 'quota-exceeded'
|
||||
UNSUPPORTED = 'unsupported'
|
||||
|
||||
|
||||
class QuotaConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider quota configuration.
|
||||
"""
|
||||
quota_type: ProviderQuotaType
|
||||
quota_unit: QuotaUnit
|
||||
quota_limit: int
|
||||
quota_used: int
|
||||
is_valid: bool
|
||||
restrict_llms: list[str] = []
|
||||
|
||||
|
||||
class SystemConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider system configuration.
|
||||
"""
|
||||
enabled: bool
|
||||
current_quota_type: Optional[ProviderQuotaType] = None
|
||||
quota_configurations: list[QuotaConfiguration] = []
|
||||
credentials: Optional[dict] = None
|
||||
|
||||
|
||||
class CustomProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider custom configuration.
|
||||
"""
|
||||
credentials: dict
|
||||
|
||||
|
||||
class CustomModelConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider custom model configuration.
|
||||
"""
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict
|
||||
|
||||
|
||||
class CustomConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider custom configuration.
|
||||
"""
|
||||
provider: Optional[CustomProviderConfiguration] = None
|
||||
models: list[CustomModelConfiguration] = []
|
118
api/core/entities/queue_entities.py
Normal file
118
api/core/entities/queue_entities.py
Normal file
|
@ -0,0 +1,118 @@
|
|||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
|
||||
|
||||
class QueueEvent(Enum):
|
||||
"""
|
||||
QueueEvent enum
|
||||
"""
|
||||
MESSAGE = "message"
|
||||
MESSAGE_REPLACE = "message-replace"
|
||||
MESSAGE_END = "message-end"
|
||||
RETRIEVER_RESOURCES = "retriever-resources"
|
||||
ANNOTATION_REPLY = "annotation-reply"
|
||||
AGENT_THOUGHT = "agent-thought"
|
||||
ERROR = "error"
|
||||
PING = "ping"
|
||||
STOP = "stop"
|
||||
|
||||
|
||||
class AppQueueEvent(BaseModel):
|
||||
"""
|
||||
QueueEvent entity
|
||||
"""
|
||||
event: QueueEvent
|
||||
|
||||
|
||||
class QueueMessageEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueMessageEvent entity
|
||||
"""
|
||||
event = QueueEvent.MESSAGE
|
||||
chunk: LLMResultChunk
|
||||
|
||||
|
||||
class QueueMessageReplaceEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueMessageReplaceEvent entity
|
||||
"""
|
||||
event = QueueEvent.MESSAGE_REPLACE
|
||||
text: str
|
||||
|
||||
|
||||
class QueueRetrieverResourcesEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueRetrieverResourcesEvent entity
|
||||
"""
|
||||
event = QueueEvent.RETRIEVER_RESOURCES
|
||||
retriever_resources: list[dict]
|
||||
|
||||
|
||||
class AnnotationReplyEvent(AppQueueEvent):
|
||||
"""
|
||||
AnnotationReplyEvent entity
|
||||
"""
|
||||
event = QueueEvent.ANNOTATION_REPLY
|
||||
message_annotation_id: str
|
||||
|
||||
|
||||
class QueueMessageEndEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueMessageEndEvent entity
|
||||
"""
|
||||
event = QueueEvent.MESSAGE_END
|
||||
llm_result: LLMResult
|
||||
|
||||
|
||||
class QueueAgentThoughtEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueAgentThoughtEvent entity
|
||||
"""
|
||||
event = QueueEvent.AGENT_THOUGHT
|
||||
agent_thought_id: str
|
||||
|
||||
|
||||
class QueueErrorEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueErrorEvent entity
|
||||
"""
|
||||
event = QueueEvent.ERROR
|
||||
error: Any
|
||||
|
||||
|
||||
class QueuePingEvent(AppQueueEvent):
|
||||
"""
|
||||
QueuePingEvent entity
|
||||
"""
|
||||
event = QueueEvent.PING
|
||||
|
||||
|
||||
class QueueStopEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueStopEvent entity
|
||||
"""
|
||||
class StopBy(Enum):
|
||||
"""
|
||||
Stop by enum
|
||||
"""
|
||||
USER_MANUAL = "user-manual"
|
||||
ANNOTATION_REPLY = "annotation-reply"
|
||||
OUTPUT_MODERATION = "output-moderation"
|
||||
|
||||
event = QueueEvent.STOP
|
||||
stopped_by: StopBy
|
||||
|
||||
|
||||
class QueueMessage(BaseModel):
|
||||
"""
|
||||
QueueMessage entity
|
||||
"""
|
||||
task_id: str
|
||||
message_id: str
|
||||
conversation_id: str
|
||||
app_mode: str
|
||||
event: AppQueueEvent
|
|
@ -14,26 +14,6 @@ class LLMBadRequestError(LLMError):
|
|||
description = "Bad Request"
|
||||
|
||||
|
||||
class LLMAPIConnectionError(LLMError):
|
||||
"""Raised when the LLM returns API connection error."""
|
||||
description = "API Connection Error"
|
||||
|
||||
|
||||
class LLMAPIUnavailableError(LLMError):
|
||||
"""Raised when the LLM returns API unavailable error."""
|
||||
description = "API Unavailable Error"
|
||||
|
||||
|
||||
class LLMRateLimitError(LLMError):
|
||||
"""Raised when the LLM returns rate limit error."""
|
||||
description = "Rate Limit Error"
|
||||
|
||||
|
||||
class LLMAuthorizationError(LLMError):
|
||||
"""Raised when the LLM returns authorization error."""
|
||||
description = "Authorization Error"
|
||||
|
||||
|
||||
class ProviderTokenNotInitError(Exception):
|
||||
"""
|
||||
Custom exception raised when the provider token is not initialized.
|
35
api/core/external_data_tool/weather_search/schema.json
Normal file
35
api/core/external_data_tool/weather_search/schema.json
Normal file
|
@ -0,0 +1,35 @@
|
|||
{
|
||||
"label": {
|
||||
"en-US": "Weather Search",
|
||||
"zh-Hans": "天气查询"
|
||||
},
|
||||
"form_schema": [
|
||||
{
|
||||
"type": "select",
|
||||
"label": {
|
||||
"en-US": "Temperature Unit",
|
||||
"zh-Hans": "温度单位"
|
||||
},
|
||||
"variable": "temperature_unit",
|
||||
"required": true,
|
||||
"options": [
|
||||
{
|
||||
"label": {
|
||||
"en-US": "Fahrenheit",
|
||||
"zh-Hans": "华氏度"
|
||||
},
|
||||
"value": "fahrenheit"
|
||||
},
|
||||
{
|
||||
"label": {
|
||||
"en-US": "Centigrade",
|
||||
"zh-Hans": "摄氏度"
|
||||
},
|
||||
"value": "centigrade"
|
||||
}
|
||||
],
|
||||
"default": "centigrade",
|
||||
"placeholder": "Please select temperature unit"
|
||||
}
|
||||
]
|
||||
}
|
45
api/core/external_data_tool/weather_search/weather_search.py
Normal file
45
api/core/external_data_tool/weather_search/weather_search.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
from typing import Optional
|
||||
|
||||
from core.external_data_tool.base import ExternalDataTool
|
||||
|
||||
|
||||
class WeatherSearch(ExternalDataTool):
|
||||
"""
|
||||
The name of custom type must be unique, keep the same with directory and file name.
|
||||
"""
|
||||
name: str = "weather_search"
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict) -> None:
|
||||
"""
|
||||
schema.json validation. It will be called when user save the config.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
config = {
|
||||
"temperature_unit": "centigrade"
|
||||
}
|
||||
|
||||
:param tenant_id: the id of workspace
|
||||
:param config: the variables of form config
|
||||
:return:
|
||||
"""
|
||||
|
||||
if not config.get('temperature_unit'):
|
||||
raise ValueError('temperature unit is required')
|
||||
|
||||
def query(self, inputs: dict, query: Optional[str] = None) -> str:
|
||||
"""
|
||||
Query the external data tool.
|
||||
|
||||
:param inputs: user inputs
|
||||
:param query: the query of chat app
|
||||
:return: the tool query result
|
||||
"""
|
||||
city = inputs.get('city')
|
||||
temperature_unit = self.config.get('temperature_unit')
|
||||
|
||||
if temperature_unit == 'fahrenheit':
|
||||
return f'Weather in {city} is 32°F'
|
||||
else:
|
||||
return f'Weather in {city} is 0°C'
|
325
api/core/features/agent_runner.py
Normal file
325
api/core/features/agent_runner.py
Normal file
|
@ -0,0 +1,325 @@
|
|||
import logging
|
||||
from typing import cast, Optional, List
|
||||
|
||||
from langchain import WikipediaAPIWrapper
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.tools import BaseTool, WikipediaQueryRun, Tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.agent.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.entities.application_entities import ModelConfigEntity, InvokeFrom, \
|
||||
AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.tool.current_datetime_tool import DatetimeTool
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
|
||||
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
|
||||
from core.tool.web_reader_tool import WebReaderTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentRunnerFeature:
|
||||
def __init__(self, tenant_id: str,
|
||||
app_orchestration_config: AppOrchestrationConfigEntity,
|
||||
model_config: ModelConfigEntity,
|
||||
config: AgentEntity,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
message: Message,
|
||||
user_id: str,
|
||||
agent_llm_callback: AgentLLMCallback,
|
||||
callback: AgentLoopGatherCallbackHandler,
|
||||
memory: Optional[TokenBufferMemory] = None,) -> None:
|
||||
"""
|
||||
Agent runner
|
||||
:param tenant_id: tenant id
|
||||
:param app_orchestration_config: app orchestration config
|
||||
:param model_config: model config
|
||||
:param config: dataset config
|
||||
:param queue_manager: queue manager
|
||||
:param message: message
|
||||
:param user_id: user id
|
||||
:param agent_llm_callback: agent llm callback
|
||||
:param callback: callback
|
||||
:param memory: memory
|
||||
"""
|
||||
self.tenant_id = tenant_id
|
||||
self.app_orchestration_config = app_orchestration_config
|
||||
self.model_config = model_config
|
||||
self.config = config
|
||||
self.queue_manager = queue_manager
|
||||
self.message = message
|
||||
self.user_id = user_id
|
||||
self.agent_llm_callback = agent_llm_callback
|
||||
self.callback = callback
|
||||
self.memory = memory
|
||||
|
||||
def run(self, query: str,
|
||||
invoke_from: InvokeFrom) -> Optional[str]:
|
||||
"""
|
||||
Retrieve agent loop result.
|
||||
:param query: query
|
||||
:param invoke_from: invoke from
|
||||
:return:
|
||||
"""
|
||||
provider = self.config.provider
|
||||
model = self.config.model
|
||||
tool_configs = self.config.tools
|
||||
|
||||
# check model is support tool calling
|
||||
provider_instance = model_provider_factory.get_provider_instance(provider=provider)
|
||||
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
# get model schema
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model=model,
|
||||
credentials=self.model_config.credentials
|
||||
)
|
||||
|
||||
if not model_schema:
|
||||
return None
|
||||
|
||||
planning_strategy = PlanningStrategy.REACT
|
||||
features = model_schema.features
|
||||
if features:
|
||||
if ModelFeature.TOOL_CALL in features \
|
||||
or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
planning_strategy = PlanningStrategy.FUNCTION_CALL
|
||||
|
||||
tools = self.to_tools(
|
||||
tool_configs=tool_configs,
|
||||
invoke_from=invoke_from,
|
||||
callbacks=[self.callback, DifyStdOutCallbackHandler()],
|
||||
)
|
||||
|
||||
if len(tools) == 0:
|
||||
return None
|
||||
|
||||
agent_configuration = AgentConfiguration(
|
||||
strategy=planning_strategy,
|
||||
model_config=self.model_config,
|
||||
tools=tools,
|
||||
memory=self.memory,
|
||||
max_iterations=10,
|
||||
max_execution_time=400.0,
|
||||
early_stopping_method="generate",
|
||||
agent_llm_callback=self.agent_llm_callback,
|
||||
callbacks=[self.callback, DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
agent_executor = AgentExecutor(agent_configuration)
|
||||
|
||||
try:
|
||||
# check if should use agent
|
||||
should_use_agent = agent_executor.should_use_agent(query)
|
||||
if not should_use_agent:
|
||||
return None
|
||||
|
||||
result = agent_executor.run(query)
|
||||
return result.output
|
||||
except Exception as ex:
|
||||
logger.exception("agent_executor run failed")
|
||||
return None
|
||||
|
||||
def to_tools(self, tool_configs: list[AgentToolEntity],
|
||||
invoke_from: InvokeFrom,
|
||||
callbacks: list[BaseCallbackHandler]) \
|
||||
-> Optional[List[BaseTool]]:
|
||||
"""
|
||||
Convert tool configs to tools
|
||||
:param tool_configs: tool configs
|
||||
:param invoke_from: invoke from
|
||||
:param callbacks: callbacks
|
||||
"""
|
||||
tools = []
|
||||
for tool_config in tool_configs:
|
||||
tool = None
|
||||
if tool_config.tool_id == "dataset":
|
||||
tool = self.to_dataset_retriever_tool(
|
||||
tool_config=tool_config.config,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
elif tool_config.tool_id == "web_reader":
|
||||
tool = self.to_web_reader_tool(
|
||||
tool_config=tool_config.config,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
elif tool_config.tool_id == "google_search":
|
||||
tool = self.to_google_search_tool(
|
||||
tool_config=tool_config.config,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
elif tool_config.tool_id == "wikipedia":
|
||||
tool = self.to_wikipedia_tool(
|
||||
tool_config=tool_config.config,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
elif tool_config.tool_id == "current_datetime":
|
||||
tool = self.to_current_datetime_tool(
|
||||
tool_config=tool_config.config,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
|
||||
if tool:
|
||||
if tool.callbacks is not None:
|
||||
tool.callbacks.extend(callbacks)
|
||||
else:
|
||||
tool.callbacks = callbacks
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
|
||||
def to_dataset_retriever_tool(self, tool_config: dict,
|
||||
invoke_from: InvokeFrom) \
|
||||
-> Optional[BaseTool]:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
:param tool_config: tool config
|
||||
:param invoke_from: invoke from
|
||||
"""
|
||||
show_retrieve_source = self.app_orchestration_config.show_retrieve_source
|
||||
|
||||
hit_callback = DatasetIndexToolCallbackHandler(
|
||||
queue_manager=self.queue_manager,
|
||||
app_id=self.message.app_id,
|
||||
message_id=self.message.id,
|
||||
user_id=self.user_id,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == self.tenant_id,
|
||||
Dataset.id == tool_config.get("id")
|
||||
).first()
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
return None
|
||||
|
||||
# pass if dataset is not available
|
||||
if (dataset and dataset.available_document_count == 0
|
||||
and dataset.available_document_count == 0):
|
||||
return None
|
||||
|
||||
# get retrieval model config
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
retrieval_model_config = dataset.retrieval_model \
|
||||
if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
# get top k
|
||||
top_k = retrieval_model_config['top_k']
|
||||
|
||||
# get score threshold
|
||||
score_threshold = None
|
||||
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
|
||||
if score_threshold_enabled:
|
||||
score_threshold = retrieval_model_config.get("score_threshold")
|
||||
|
||||
tool = DatasetRetrieverTool.from_dataset(
|
||||
dataset=dataset,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=show_retrieve_source,
|
||||
retriever_from=invoke_from.to_source()
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
def to_web_reader_tool(self, tool_config: dict,
|
||||
invoke_from: InvokeFrom) -> Optional[BaseTool]:
|
||||
"""
|
||||
A tool for reading web pages
|
||||
:param tool_config: tool config
|
||||
:param invoke_from: invoke from
|
||||
:return:
|
||||
"""
|
||||
model_parameters = {
|
||||
"temperature": 0,
|
||||
"max_tokens": 500
|
||||
}
|
||||
|
||||
tool = WebReaderTool(
|
||||
model_config=self.model_config,
|
||||
model_parameters=model_parameters,
|
||||
max_chunk_length=4000,
|
||||
continue_reading=True
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
def to_google_search_tool(self, tool_config: dict,
|
||||
invoke_from: InvokeFrom) -> Optional[BaseTool]:
|
||||
"""
|
||||
A tool for performing a Google search and extracting snippets and webpages
|
||||
:param tool_config: tool config
|
||||
:param invoke_from: invoke from
|
||||
:return:
|
||||
"""
|
||||
tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
|
||||
func_kwargs = tool_provider.credentials_to_func_kwargs()
|
||||
if not func_kwargs:
|
||||
return None
|
||||
|
||||
tool = Tool(
|
||||
name="google_search",
|
||||
description="A tool for performing a Google search and extracting snippets and webpages "
|
||||
"when you need to search for something you don't know or when your information "
|
||||
"is not up to date. "
|
||||
"Input should be a search query.",
|
||||
func=OptimizedSerpAPIWrapper(**func_kwargs).run,
|
||||
args_schema=OptimizedSerpAPIInput
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
def to_current_datetime_tool(self, tool_config: dict,
|
||||
invoke_from: InvokeFrom) -> Optional[BaseTool]:
|
||||
"""
|
||||
A tool for getting the current date and time
|
||||
:param tool_config: tool config
|
||||
:param invoke_from: invoke from
|
||||
:return:
|
||||
"""
|
||||
return DatetimeTool()
|
||||
|
||||
def to_wikipedia_tool(self, tool_config: dict,
|
||||
invoke_from: InvokeFrom) -> Optional[BaseTool]:
|
||||
"""
|
||||
A tool for searching Wikipedia
|
||||
:param tool_config: tool config
|
||||
:param invoke_from: invoke from
|
||||
:return:
|
||||
"""
|
||||
class WikipediaInput(BaseModel):
|
||||
query: str = Field(..., description="search query.")
|
||||
|
||||
return WikipediaQueryRun(
|
||||
name="wikipedia",
|
||||
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
|
||||
args_schema=WikipediaInput
|
||||
)
|
119
api/core/features/annotation_reply.py
Normal file
119
api/core/features/annotation_reply.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.model import App, Message, AppAnnotationSetting, MessageAnnotation
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnnotationReplyFeature:
|
||||
def query(self, app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
|
||||
"""
|
||||
Query app annotations to reply
|
||||
:param app_record: app record
|
||||
:param message: message
|
||||
:param query: query
|
||||
:param user_id: user id
|
||||
:param invoke_from: invoke from
|
||||
:return:
|
||||
"""
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_record.id).first()
|
||||
|
||||
if not annotation_setting:
|
||||
return None
|
||||
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
|
||||
try:
|
||||
score_threshold = annotation_setting.score_threshold or 1
|
||||
embedding_provider_name = collection_binding_detail.provider_name
|
||||
embedding_model_name = collection_binding_detail.model_name
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=app_record.tenant_id,
|
||||
provider=embedding_provider_name,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=embedding_model_name
|
||||
)
|
||||
|
||||
# get embedding model
|
||||
embeddings = CacheEmbedding(model_instance)
|
||||
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_provider_name,
|
||||
embedding_model_name,
|
||||
'annotation'
|
||||
)
|
||||
|
||||
dataset = Dataset(
|
||||
id=app_record.id,
|
||||
tenant_id=app_record.tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
embedding_model_provider=embedding_provider_name,
|
||||
embedding_model=embedding_model_name,
|
||||
collection_binding_id=dataset_collection_binding.id
|
||||
)
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings,
|
||||
attributes=['doc_id', 'annotation_id', 'app_id']
|
||||
)
|
||||
|
||||
documents = vector_index.search(
|
||||
query=query,
|
||||
search_type='similarity_score_threshold',
|
||||
search_kwargs={
|
||||
'k': 1,
|
||||
'score_threshold': score_threshold,
|
||||
'filter': {
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if documents:
|
||||
annotation_id = documents[0].metadata['annotation_id']
|
||||
score = documents[0].metadata['score']
|
||||
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
|
||||
if annotation:
|
||||
if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
|
||||
from_source = 'api'
|
||||
else:
|
||||
from_source = 'console'
|
||||
|
||||
# insert annotation history
|
||||
AppAnnotationService.add_annotation_history(annotation.id,
|
||||
app_record.id,
|
||||
annotation.question,
|
||||
annotation.content,
|
||||
query,
|
||||
user_id,
|
||||
message.id,
|
||||
from_source,
|
||||
score)
|
||||
|
||||
return annotation
|
||||
except Exception as e:
|
||||
logger.warning(f'Query annotation failed, exception: {str(e)}.')
|
||||
return None
|
||||
|
||||
return None
|
181
api/core/features/dataset_retrieval.py
Normal file
181
api/core/features/dataset_retrieval.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
from typing import cast, Optional, List
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.entities.application_entities import DatasetEntity, ModelConfigEntity, InvokeFrom, DatasetRetrieveConfigEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class DatasetRetrievalFeature:
|
||||
def retrieve(self, tenant_id: str,
|
||||
model_config: ModelConfigEntity,
|
||||
config: DatasetEntity,
|
||||
query: str,
|
||||
invoke_from: InvokeFrom,
|
||||
show_retrieve_source: bool,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
|
||||
"""
|
||||
Retrieve dataset.
|
||||
:param tenant_id: tenant id
|
||||
:param model_config: model config
|
||||
:param config: dataset config
|
||||
:param query: query
|
||||
:param invoke_from: invoke from
|
||||
:param show_retrieve_source: show retrieve source
|
||||
:param hit_callback: hit callback
|
||||
:param memory: memory
|
||||
:return:
|
||||
"""
|
||||
dataset_ids = config.dataset_ids
|
||||
retrieve_config = config.retrieve_config
|
||||
|
||||
# check model is support tool calling
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
# get model schema
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model=model_config.model,
|
||||
credentials=model_config.credentials
|
||||
)
|
||||
|
||||
if not model_schema:
|
||||
return None
|
||||
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
features = model_schema.features
|
||||
if features:
|
||||
if ModelFeature.TOOL_CALL in features \
|
||||
or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
planning_strategy = PlanningStrategy.ROUTER
|
||||
|
||||
dataset_retriever_tools = self.to_dataset_retriever_tool(
|
||||
tenant_id=tenant_id,
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=retrieve_config,
|
||||
return_resource=show_retrieve_source,
|
||||
invoke_from=invoke_from,
|
||||
hit_callback=hit_callback
|
||||
)
|
||||
|
||||
if len(dataset_retriever_tools) == 0:
|
||||
return None
|
||||
|
||||
agent_configuration = AgentConfiguration(
|
||||
strategy=planning_strategy,
|
||||
model_config=model_config,
|
||||
tools=dataset_retriever_tools,
|
||||
memory=memory,
|
||||
max_iterations=10,
|
||||
max_execution_time=400.0,
|
||||
early_stopping_method="generate"
|
||||
)
|
||||
|
||||
agent_executor = AgentExecutor(agent_configuration)
|
||||
|
||||
should_use_agent = agent_executor.should_use_agent(query)
|
||||
if not should_use_agent:
|
||||
return None
|
||||
|
||||
result = agent_executor.run(query)
|
||||
|
||||
return result.output
|
||||
|
||||
def to_dataset_retriever_tool(self, tenant_id: str,
|
||||
dataset_ids: list[str],
|
||||
retrieve_config: DatasetRetrieveConfigEntity,
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler) \
|
||||
-> Optional[List[BaseTool]]:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
:param tenant_id: tenant id
|
||||
:param dataset_ids: dataset ids
|
||||
:param retrieve_config: retrieve config
|
||||
:param return_resource: return resource
|
||||
:param invoke_from: invoke from
|
||||
:param hit_callback: hit callback
|
||||
"""
|
||||
tools = []
|
||||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
continue
|
||||
|
||||
# pass if dataset is not available
|
||||
if (dataset and dataset.available_document_count == 0
|
||||
and dataset.available_document_count == 0):
|
||||
continue
|
||||
|
||||
available_datasets.append(dataset)
|
||||
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
# get retrieval model config
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
for dataset in available_datasets:
|
||||
retrieval_model_config = dataset.retrieval_model \
|
||||
if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
# get top k
|
||||
top_k = retrieval_model_config['top_k']
|
||||
|
||||
# get score threshold
|
||||
score_threshold = None
|
||||
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
|
||||
if score_threshold_enabled:
|
||||
score_threshold = retrieval_model_config.get("score_threshold")
|
||||
|
||||
tool = DatasetRetrieverTool.from_dataset(
|
||||
dataset=dataset,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source()
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
tool = DatasetMultiRetrieverTool.from_dataset(
|
||||
dataset_ids=[dataset.id for dataset in available_datasets],
|
||||
tenant_id=tenant_id,
|
||||
top_k=retrieve_config.top_k or 2,
|
||||
score_threshold=(retrieve_config.score_threshold or 0.5)
|
||||
if retrieve_config.reranking_model.get('score_threshold_enabled', False) else None,
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
|
||||
reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
96
api/core/features/external_data_fetch.py
Normal file
96
api/core/features/external_data_fetch.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
import concurrent
|
||||
import json
|
||||
import logging
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from flask import current_app, Flask
|
||||
|
||||
from core.entities.application_entities import ExternalDataVariableEntity
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExternalDataFetchFeature:
|
||||
def fetch(self, tenant_id: str,
|
||||
app_id: str,
|
||||
external_data_tools: list[ExternalDataVariableEntity],
|
||||
inputs: dict,
|
||||
query: str) -> dict:
|
||||
"""
|
||||
Fill in variable inputs from external data tools if exists.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param app_id: app id
|
||||
:param external_data_tools: external data tools configs
|
||||
:param inputs: the inputs
|
||||
:param query: the query
|
||||
:return: the filled inputs
|
||||
"""
|
||||
# Group tools by type and config
|
||||
grouped_tools = {}
|
||||
for tool in external_data_tools:
|
||||
tool_key = (tool.type, json.dumps(tool.config, sort_keys=True))
|
||||
grouped_tools.setdefault(tool_key, []).append(tool)
|
||||
|
||||
results = {}
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = {}
|
||||
for tool in external_data_tools:
|
||||
future = executor.submit(
|
||||
self._query_external_data_tool,
|
||||
current_app._get_current_object(),
|
||||
tenant_id,
|
||||
app_id,
|
||||
tool,
|
||||
inputs,
|
||||
query
|
||||
)
|
||||
|
||||
futures[future] = tool
|
||||
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
tool_variable, result = future.result()
|
||||
results[tool_variable] = result
|
||||
|
||||
inputs.update(results)
|
||||
return inputs
|
||||
|
||||
def _query_external_data_tool(self, flask_app: Flask,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
external_data_tool: ExternalDataVariableEntity,
|
||||
inputs: dict,
|
||||
query: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Query external data tool.
|
||||
:param flask_app: flask app
|
||||
:param tenant_id: tenant id
|
||||
:param app_id: app id
|
||||
:param external_data_tool: external data tool
|
||||
:param inputs: inputs
|
||||
:param query: query
|
||||
:return:
|
||||
"""
|
||||
with flask_app.app_context():
|
||||
tool_variable = external_data_tool.variable
|
||||
tool_type = external_data_tool.type
|
||||
tool_config = external_data_tool.config
|
||||
|
||||
external_data_tool_factory = ExternalDataToolFactory(
|
||||
name=tool_type,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
variable=tool_variable,
|
||||
config=tool_config
|
||||
)
|
||||
|
||||
# query external data tool
|
||||
result = external_data_tool_factory.query(
|
||||
inputs=inputs,
|
||||
query=query
|
||||
)
|
||||
|
||||
return tool_variable, result
|
32
api/core/features/hosting_moderation.py
Normal file
32
api/core/features/hosting_moderation.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
import logging
|
||||
|
||||
from core.entities.application_entities import ApplicationGenerateEntity
|
||||
from core.helper import moderation
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HostingModerationFeature:
|
||||
def check(self, application_generate_entity: ApplicationGenerateEntity,
|
||||
prompt_messages: list[PromptMessage]) -> bool:
|
||||
"""
|
||||
Check hosting moderation
|
||||
:param application_generate_entity: application generate entity
|
||||
:param prompt_messages: prompt messages
|
||||
:return:
|
||||
"""
|
||||
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
|
||||
model_config = app_orchestration_config.model_config
|
||||
|
||||
text = ""
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message.content, str):
|
||||
text += prompt_message.content + "\n"
|
||||
|
||||
moderation_result = moderation.check_moderation(
|
||||
model_config,
|
||||
text
|
||||
)
|
||||
|
||||
return moderation_result
|
50
api/core/features/moderation.py
Normal file
50
api/core/features/moderation.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
from core.entities.application_entities import AppOrchestrationConfigEntity
|
||||
from core.moderation.base import ModerationAction, ModerationException
|
||||
from core.moderation.factory import ModerationFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModerationFeature:
|
||||
def check(self, app_id: str,
|
||||
tenant_id: str,
|
||||
app_orchestration_config_entity: AppOrchestrationConfigEntity,
|
||||
inputs: dict,
|
||||
query: str) -> Tuple[bool, dict, str]:
|
||||
"""
|
||||
Process sensitive_word_avoidance.
|
||||
:param app_id: app id
|
||||
:param tenant_id: tenant id
|
||||
:param app_orchestration_config_entity: app orchestration config entity
|
||||
:param inputs: inputs
|
||||
:param query: query
|
||||
:return:
|
||||
"""
|
||||
if not app_orchestration_config_entity.sensitive_word_avoidance:
|
||||
return False, inputs, query
|
||||
|
||||
sensitive_word_avoidance_config = app_orchestration_config_entity.sensitive_word_avoidance
|
||||
moderation_type = sensitive_word_avoidance_config.type
|
||||
|
||||
moderation_factory = ModerationFactory(
|
||||
name=moderation_type,
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
config=sensitive_word_avoidance_config.config
|
||||
)
|
||||
|
||||
moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
|
||||
|
||||
if not moderation_result.flagged:
|
||||
return False, inputs, query
|
||||
|
||||
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
raise ModerationException(moderation_result.preset_response)
|
||||
elif moderation_result.action == ModerationAction.OVERRIDED:
|
||||
inputs = moderation_result.inputs
|
||||
query = moderation_result.query
|
||||
|
||||
return True, inputs, query
|
|
@ -4,7 +4,7 @@ from typing import Optional
|
|||
from pydantic import BaseModel
|
||||
|
||||
from core.file.upload_file_parser import UploadFileParser
|
||||
from core.model_providers.models.entity.message import PromptMessageFile, ImagePromptMessageFile
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from extensions.ext_database import db
|
||||
from models.model import UploadFile
|
||||
|
||||
|
@ -50,14 +50,14 @@ class FileObj(BaseModel):
|
|||
return self._get_data(force_url=True)
|
||||
|
||||
@property
|
||||
def prompt_message_file(self) -> PromptMessageFile:
|
||||
def prompt_message_content(self) -> ImagePromptMessageContent:
|
||||
if self.type == FileType.IMAGE:
|
||||
image_config = self.file_config.get('image')
|
||||
|
||||
return ImagePromptMessageFile(
|
||||
return ImagePromptMessageContent(
|
||||
data=self.data,
|
||||
detail=ImagePromptMessageFile.DETAIL.HIGH
|
||||
if image_config.get("detail") == "high" else ImagePromptMessageFile.DETAIL.LOW
|
||||
detail=ImagePromptMessageContent.DETAIL.HIGH
|
||||
if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW
|
||||
)
|
||||
|
||||
def _get_data(self, force_url: bool = False) -> Optional[str]:
|
||||
|
|
|
@ -3,10 +3,10 @@ import logging
|
|||
|
||||
from langchain.schema import OutputParserException
|
||||
|
||||
from core.model_providers.error import LLMError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.message_entities import UserPromptMessage, SystemPromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
|
||||
|
||||
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
||||
|
@ -26,17 +26,22 @@ class LLMGenerator:
|
|||
|
||||
prompt += query + "\n"
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
temperature=1,
|
||||
max_tokens=100
|
||||
)
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
prompts = [PromptMessage(content=prompt)]
|
||||
response = model_instance.run(prompts)
|
||||
answer = response.content
|
||||
prompts = [UserPromptMessage(content=prompt)]
|
||||
response = model_instance.invoke_llm(
|
||||
prompt_messages=prompts,
|
||||
model_parameters={
|
||||
"max_tokens": 100,
|
||||
"temperature": 1
|
||||
},
|
||||
stream=False
|
||||
)
|
||||
answer = response.message.content
|
||||
|
||||
result_dict = json.loads(answer)
|
||||
answer = result_dict['Your Output']
|
||||
|
@ -62,22 +67,28 @@ class LLMGenerator:
|
|||
})
|
||||
|
||||
try:
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=256,
|
||||
temperature=0
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
except InvokeAuthorizationError:
|
||||
return []
|
||||
|
||||
prompt_messages = [PromptMessage(content=prompt)]
|
||||
prompt_messages = [UserPromptMessage(content=prompt)]
|
||||
|
||||
try:
|
||||
output = model_instance.run(prompt_messages)
|
||||
questions = output_parser.parse(output.content)
|
||||
except LLMError:
|
||||
response = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters={
|
||||
"max_tokens": 256,
|
||||
"temperature": 0
|
||||
},
|
||||
stream=False
|
||||
)
|
||||
|
||||
questions = output_parser.parse(response.message.content)
|
||||
except InvokeError:
|
||||
questions = []
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
@ -105,20 +116,26 @@ class LLMGenerator:
|
|||
remove_template_variables=False
|
||||
)
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=512,
|
||||
temperature=0
|
||||
)
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
prompt_messages = [PromptMessage(content=prompt)]
|
||||
prompt_messages = [UserPromptMessage(content=prompt)]
|
||||
|
||||
try:
|
||||
output = model_instance.run(prompt_messages)
|
||||
rule_config = output_parser.parse(output.content)
|
||||
except LLMError as e:
|
||||
response = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters={
|
||||
"max_tokens": 512,
|
||||
"temperature": 0
|
||||
},
|
||||
stream=False
|
||||
)
|
||||
|
||||
rule_config = output_parser.parse(response.message.content)
|
||||
except InvokeError as e:
|
||||
raise e
|
||||
except OutputParserException:
|
||||
raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
|
||||
|
@ -136,18 +153,24 @@ class LLMGenerator:
|
|||
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
|
||||
prompt = GENERATOR_QA_PROMPT.format(language=document_language)
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=2000
|
||||
)
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
prompts = [
|
||||
PromptMessage(content=prompt, type=MessageType.SYSTEM),
|
||||
PromptMessage(content=query)
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content=prompt),
|
||||
UserPromptMessage(content=query)
|
||||
]
|
||||
|
||||
response = model_instance.run(prompts)
|
||||
answer = response.content
|
||||
response = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters={
|
||||
"max_tokens": 2000
|
||||
},
|
||||
stream=False
|
||||
)
|
||||
|
||||
answer = response.message.content
|
||||
return answer.strip()
|
||||
|
|
|
@ -18,3 +18,17 @@ def encrypt_token(tenant_id: str, token: str):
|
|||
|
||||
def decrypt_token(tenant_id: str, token: str):
|
||||
return rsa.decrypt(base64.b64decode(token), tenant_id)
|
||||
|
||||
|
||||
def batch_decrypt_token(tenant_id: str, tokens: list[str]):
|
||||
rsa_key, cipher_rsa = rsa.get_decrypt_decoding(tenant_id)
|
||||
|
||||
return [rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa) for token in tokens]
|
||||
|
||||
|
||||
def get_decrypt_decoding(tenant_id: str):
|
||||
return rsa.get_decrypt_decoding(tenant_id)
|
||||
|
||||
|
||||
def decrypt_token_with_decoding(token: str, rsa_key, cipher_rsa):
|
||||
return rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa)
|
||||
|
|
22
api/core/helper/lru_cache.py
Normal file
22
api/core/helper/lru_cache.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LRUCache:
|
||||
def __init__(self, capacity: int):
|
||||
self.cache = OrderedDict()
|
||||
self.capacity = capacity
|
||||
|
||||
def get(self, key: Any) -> Any:
|
||||
if key not in self.cache:
|
||||
return None
|
||||
else:
|
||||
self.cache.move_to_end(key) # move the key to the end of the OrderedDict
|
||||
return self.cache[key]
|
||||
|
||||
def put(self, key: Any, value: Any) -> None:
|
||||
if key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
self.cache[key] = value
|
||||
if len(self.cache) > self.capacity:
|
||||
self.cache.popitem(last=False) # pop the first item
|
|
@ -1,18 +1,27 @@
|
|||
import logging
|
||||
import random
|
||||
|
||||
import openai
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
|
||||
from extensions.ext_hosting_provider import hosting_configuration
|
||||
from models.provider import ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_moderation(model_config: ModelConfigEntity, text: str) -> bool:
|
||||
moderation_config = hosting_configuration.moderation_config
|
||||
if (moderation_config and moderation_config.enabled is True
|
||||
and 'openai' in hosting_configuration.provider_map
|
||||
and hosting_configuration.provider_map['openai'].enabled is True
|
||||
):
|
||||
using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
|
||||
provider_name = model_config.provider
|
||||
if using_provider_type == ProviderType.SYSTEM \
|
||||
and provider_name in moderation_config.providers:
|
||||
hosting_openai_config = hosting_configuration.provider_map['openai']
|
||||
|
||||
def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
|
||||
if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
|
||||
if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
|
||||
and model_provider.provider_name in hosted_config.moderation.providers:
|
||||
# 2000 text per chunk
|
||||
length = 2000
|
||||
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
|
||||
|
@ -23,14 +32,17 @@ def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
|
|||
text_chunk = random.choice(text_chunks)
|
||||
|
||||
try:
|
||||
moderation_result = openai.Moderation.create(input=text_chunk,
|
||||
api_key=hosted_model_providers.openai.api_key)
|
||||
except Exception as ex:
|
||||
logging.exception(ex)
|
||||
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
|
||||
|
||||
for result in moderation_result.results:
|
||||
if result['flagged'] is True:
|
||||
return False
|
||||
model_type_instance = OpenAIModerationModel()
|
||||
moderation_result = model_type_instance.invoke(
|
||||
model='text-moderation-stable',
|
||||
credentials=hosting_openai_config.credentials,
|
||||
text=text_chunk
|
||||
)
|
||||
|
||||
if moderation_result is True:
|
||||
return True
|
||||
except Exception as ex:
|
||||
logger.exception(ex)
|
||||
raise InvokeBadRequestError('Rate limit exceeded, please try again later.')
|
||||
|
||||
return False
|
||||
|
|
213
api/core/hosting_configuration.py
Normal file
213
api/core/hosting_configuration.py
Normal file
|
@ -0,0 +1,213 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from models.provider import ProviderQuotaType
|
||||
|
||||
|
||||
class HostingQuota(BaseModel):
|
||||
quota_type: ProviderQuotaType
|
||||
restrict_llms: list[str] = []
|
||||
|
||||
|
||||
class TrialHostingQuota(HostingQuota):
|
||||
quota_type: ProviderQuotaType = ProviderQuotaType.TRIAL
|
||||
quota_limit: int = 0
|
||||
"""Quota limit for the hosting provider models. -1 means unlimited."""
|
||||
|
||||
|
||||
class PaidHostingQuota(HostingQuota):
|
||||
quota_type: ProviderQuotaType = ProviderQuotaType.PAID
|
||||
stripe_price_id: str = None
|
||||
increase_quota: int = 1
|
||||
min_quantity: int = 20
|
||||
max_quantity: int = 100
|
||||
|
||||
|
||||
class FreeHostingQuota(HostingQuota):
|
||||
quota_type: ProviderQuotaType = ProviderQuotaType.FREE
|
||||
|
||||
|
||||
class HostingProvider(BaseModel):
|
||||
enabled: bool = False
|
||||
credentials: Optional[dict] = None
|
||||
quota_unit: Optional[QuotaUnit] = None
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
|
||||
class HostedModerationConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
providers: list[str] = []
|
||||
|
||||
|
||||
class HostingConfiguration:
|
||||
provider_map: dict[str, HostingProvider] = {}
|
||||
moderation_config: HostedModerationConfig = None
|
||||
|
||||
def init_app(self, app: Flask):
|
||||
if app.config.get('EDITION') != 'CLOUD':
|
||||
return
|
||||
|
||||
self.provider_map["openai"] = self.init_openai()
|
||||
self.provider_map["anthropic"] = self.init_anthropic()
|
||||
self.provider_map["minimax"] = self.init_minimax()
|
||||
self.provider_map["spark"] = self.init_spark()
|
||||
self.provider_map["zhipuai"] = self.init_zhipuai()
|
||||
|
||||
self.moderation_config = self.init_moderation_config()
|
||||
|
||||
def init_openai(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TIMES
|
||||
if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true':
|
||||
credentials = {
|
||||
"openai_api_key": os.environ.get("HOSTED_OPENAI_API_KEY"),
|
||||
}
|
||||
|
||||
if os.environ.get("HOSTED_OPENAI_API_BASE"):
|
||||
credentials["openai_api_base"] = os.environ.get("HOSTED_OPENAI_API_BASE")
|
||||
|
||||
if os.environ.get("HOSTED_OPENAI_API_ORGANIZATION"):
|
||||
credentials["openai_organization"] = os.environ.get("HOSTED_OPENAI_API_ORGANIZATION")
|
||||
|
||||
quotas = []
|
||||
hosted_quota_limit = int(os.environ.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
|
||||
if hosted_quota_limit != -1 or hosted_quota_limit > 0:
|
||||
trial_quota = TrialHostingQuota(
|
||||
quota_limit=hosted_quota_limit,
|
||||
restrict_llms=[
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-1106",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"text-davinci-003"
|
||||
]
|
||||
)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if os.environ.get("HOSTED_OPENAI_PAID_ENABLED") and os.environ.get(
|
||||
"HOSTED_OPENAI_PAID_ENABLED").lower() == 'true':
|
||||
paid_quota = PaidHostingQuota(
|
||||
stripe_price_id=os.environ.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
|
||||
increase_quota=int(os.environ.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA", "1")),
|
||||
min_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MIN_QUANTITY", "1")),
|
||||
max_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MAX_QUANTITY", "1"))
|
||||
)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=True,
|
||||
credentials=credentials,
|
||||
quota_unit=quota_unit,
|
||||
quotas=quotas
|
||||
)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_anthropic(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
if os.environ.get("HOSTED_ANTHROPIC_ENABLED") and os.environ.get("HOSTED_ANTHROPIC_ENABLED").lower() == 'true':
|
||||
credentials = {
|
||||
"anthropic_api_key": os.environ.get("HOSTED_ANTHROPIC_API_KEY"),
|
||||
}
|
||||
|
||||
if os.environ.get("HOSTED_ANTHROPIC_API_BASE"):
|
||||
credentials["anthropic_api_url"] = os.environ.get("HOSTED_ANTHROPIC_API_BASE")
|
||||
|
||||
quotas = []
|
||||
hosted_quota_limit = int(os.environ.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
|
||||
if hosted_quota_limit != -1 or hosted_quota_limit > 0:
|
||||
trial_quota = TrialHostingQuota(
|
||||
quota_limit=hosted_quota_limit
|
||||
)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if os.environ.get("HOSTED_ANTHROPIC_PAID_ENABLED") and os.environ.get(
|
||||
"HOSTED_ANTHROPIC_PAID_ENABLED").lower() == 'true':
|
||||
paid_quota = PaidHostingQuota(
|
||||
stripe_price_id=os.environ.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
|
||||
increase_quota=int(os.environ.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA", "1000000")),
|
||||
min_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY", "20")),
|
||||
max_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY", "100"))
|
||||
)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=True,
|
||||
credentials=credentials,
|
||||
quota_unit=quota_unit,
|
||||
quotas=quotas
|
||||
)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_minimax(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
if os.environ.get("HOSTED_MINIMAX_ENABLED") and os.environ.get("HOSTED_MINIMAX_ENABLED").lower() == 'true':
|
||||
quotas = [FreeHostingQuota()]
|
||||
|
||||
return HostingProvider(
|
||||
enabled=True,
|
||||
credentials=None, # use credentials from the provider
|
||||
quota_unit=quota_unit,
|
||||
quotas=quotas
|
||||
)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_spark(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
if os.environ.get("HOSTED_SPARK_ENABLED") and os.environ.get("HOSTED_SPARK_ENABLED").lower() == 'true':
|
||||
quotas = [FreeHostingQuota()]
|
||||
|
||||
return HostingProvider(
|
||||
enabled=True,
|
||||
credentials=None, # use credentials from the provider
|
||||
quota_unit=quota_unit,
|
||||
quotas=quotas
|
||||
)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_zhipuai(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
if os.environ.get("HOSTED_ZHIPUAI_ENABLED") and os.environ.get("HOSTED_ZHIPUAI_ENABLED").lower() == 'true':
|
||||
quotas = [FreeHostingQuota()]
|
||||
|
||||
return HostingProvider(
|
||||
enabled=True,
|
||||
credentials=None, # use credentials from the provider
|
||||
quota_unit=quota_unit,
|
||||
quotas=quotas
|
||||
)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_moderation_config(self) -> HostedModerationConfig:
|
||||
if os.environ.get("HOSTED_MODERATION_ENABLED") and os.environ.get("HOSTED_MODERATION_ENABLED").lower() == 'true' \
|
||||
and os.environ.get("HOSTED_MODERATION_PROVIDERS"):
|
||||
return HostedModerationConfig(
|
||||
enabled=True,
|
||||
providers=os.environ.get("HOSTED_MODERATION_PROVIDERS").split(',')
|
||||
)
|
||||
|
||||
return HostedModerationConfig(
|
||||
enabled=False
|
||||
)
|
|
@ -1,18 +1,12 @@
|
|||
import json
|
||||
|
||||
from flask import current_app
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs
|
||||
from core.model_providers.models.llm.openai_model import OpenAIModel
|
||||
from core.model_providers.providers.openai_provider import OpenAIProvider
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.dataset import Dataset
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
class IndexBuilder:
|
||||
|
@ -22,10 +16,12 @@ class IndexBuilder:
|
|||
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
|
||||
return None
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
|
|
@ -18,9 +18,11 @@ from core.data_loader.loader.notion import NotionLoader
|
|||
from core.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.index.index import IndexBuilder
|
||||
from core.model_providers.error import ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import MessageType
|
||||
from core.model_manager import ModelManager
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.model_runtime.entities.model_entities import ModelType, PriceType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
|
@ -36,6 +38,7 @@ class IndexingRunner:
|
|||
|
||||
def __init__(self):
|
||||
self.storage = storage
|
||||
self.model_manager = ModelManager()
|
||||
|
||||
def run(self, dataset_documents: List[DatasetDocument]):
|
||||
"""Run the indexing process."""
|
||||
|
@ -210,7 +213,7 @@ class IndexingRunner:
|
|||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
embedding_model = None
|
||||
embedding_model_instance = None
|
||||
if dataset_id:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=dataset_id
|
||||
|
@ -218,15 +221,17 @@ class IndexingRunner:
|
|||
if not dataset:
|
||||
raise ValueError('Dataset not found.')
|
||||
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
embedding_model_instance = self.model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
if indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
tokens = 0
|
||||
preview_texts = []
|
||||
|
@ -255,32 +260,56 @@ class IndexingRunner:
|
|||
for document in documents:
|
||||
if len(preview_texts) < 5:
|
||||
preview_texts.append(document.page_content)
|
||||
if indexing_technique == 'high_quality' or embedding_model:
|
||||
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
|
||||
if indexing_technique == 'high_quality' or embedding_model_instance:
|
||||
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||
tokens += embedding_model_type_instance.get_num_tokens(
|
||||
model=embedding_model_instance.model,
|
||||
credentials=embedding_model_instance.credentials,
|
||||
texts=[self.filter_string(document.page_content)]
|
||||
)
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
|
||||
doc_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
price_info = model_type_instance.get_price(
|
||||
model=model_instance.model,
|
||||
credentials=model_instance.credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=total_segments * 2000,
|
||||
)
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"total_price": '{:f}'.format(price_info.total_amount),
|
||||
"currency": price_info.currency,
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
}
|
||||
if embedding_model_instance:
|
||||
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
|
||||
embedding_price_info = embedding_model_type_instance.get_price(
|
||||
model=embedding_model_instance.model,
|
||||
credentials=embedding_model_instance.credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
|
||||
"currency": embedding_model.get_currency() if embedding_model else 'USD',
|
||||
"total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0,
|
||||
"currency": embedding_price_info.currency if embedding_model_instance else 'USD',
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
|
@ -290,7 +319,7 @@ class IndexingRunner:
|
|||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
embedding_model = None
|
||||
embedding_model_instance = None
|
||||
if dataset_id:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=dataset_id
|
||||
|
@ -298,15 +327,17 @@ class IndexingRunner:
|
|||
if not dataset:
|
||||
raise ValueError('Dataset not found.')
|
||||
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
embedding_model_instance = self.model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
if indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
# load data from notion
|
||||
tokens = 0
|
||||
|
@ -349,35 +380,63 @@ class IndexingRunner:
|
|||
processing_rule=processing_rule
|
||||
)
|
||||
total_segments += len(documents)
|
||||
|
||||
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||
|
||||
for document in documents:
|
||||
if len(preview_texts) < 5:
|
||||
preview_texts.append(document.page_content)
|
||||
if indexing_technique == 'high_quality' or embedding_model:
|
||||
tokens += embedding_model.get_num_tokens(document.page_content)
|
||||
if indexing_technique == 'high_quality' or embedding_model_instance:
|
||||
tokens += embedding_model_type_instance.get_num_tokens(
|
||||
model=embedding_model_instance.model,
|
||||
credentials=embedding_model_instance.credentials,
|
||||
texts=[document.page_content]
|
||||
)
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
|
||||
doc_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
|
||||
price_info = model_type_instance.get_price(
|
||||
model=model_instance.model,
|
||||
credentials=model_instance.credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=total_segments * 2000,
|
||||
)
|
||||
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"total_price": '{:f}'.format(price_info.total_amount),
|
||||
"currency": price_info.currency,
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||
embedding_price_info = embedding_model_type_instance.get_price(
|
||||
model=embedding_model_instance.model,
|
||||
credentials=embedding_model_instance.credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
|
||||
"currency": embedding_model.get_currency() if embedding_model else 'USD',
|
||||
"total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0,
|
||||
"currency": embedding_price_info.currency if embedding_model_instance else 'USD',
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
|
@ -656,25 +715,36 @@ class IndexingRunner:
|
|||
"""
|
||||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
embedding_model = None
|
||||
embedding_model_instance = None
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
embedding_model_instance = self.model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
|
||||
# chunk nodes by chunk size
|
||||
indexing_start_at = time.perf_counter()
|
||||
tokens = 0
|
||||
chunk_size = 100
|
||||
|
||||
embedding_model_type_instance = None
|
||||
if embedding_model_instance:
|
||||
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||
|
||||
for i in range(0, len(documents), chunk_size):
|
||||
# check document is paused
|
||||
self._check_document_paused_status(dataset_document.id)
|
||||
chunk_documents = documents[i:i + chunk_size]
|
||||
if dataset.indexing_technique == 'high_quality' or embedding_model:
|
||||
if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
|
||||
tokens += sum(
|
||||
embedding_model.get_num_tokens(document.page_content)
|
||||
embedding_model_type_instance.get_num_tokens(
|
||||
embedding_model_instance.model,
|
||||
embedding_model_instance.credentials,
|
||||
[document.page_content]
|
||||
)
|
||||
for document in chunk_documents
|
||||
)
|
||||
|
||||
|
|
|
@ -1,95 +0,0 @@
|
|||
from typing import Any, List, Dict
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import get_buffer_string, BaseMessage
|
||||
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message
|
||||
|
||||
|
||||
class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
|
||||
conversation: Conversation
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "Assistant"
|
||||
model_instance: BaseLLM
|
||||
memory_key: str = "chat_history"
|
||||
max_token_limit: int = 2000
|
||||
message_limit: int = 10
|
||||
|
||||
@property
|
||||
def buffer(self) -> List[BaseMessage]:
|
||||
"""String buffer of memory."""
|
||||
app_model = self.conversation.app
|
||||
|
||||
# fetch limited messages desc, and return reversed
|
||||
messages = db.session.query(Message).filter(
|
||||
Message.conversation_id == self.conversation.id,
|
||||
Message.answer != ''
|
||||
).order_by(Message.created_at.desc()).limit(self.message_limit).all()
|
||||
|
||||
messages = list(reversed(messages))
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=self.conversation.app_id)
|
||||
|
||||
chat_messages: List[PromptMessage] = []
|
||||
for message in messages:
|
||||
files = message.message_files
|
||||
if files:
|
||||
file_objs = message_file_parser.transform_message_files(
|
||||
files, message.app_model_config
|
||||
)
|
||||
|
||||
prompt_message_files = [file_obj.prompt_message_file for file_obj in file_objs]
|
||||
chat_messages.append(PromptMessage(
|
||||
content=message.query,
|
||||
type=MessageType.USER,
|
||||
files=prompt_message_files
|
||||
))
|
||||
else:
|
||||
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
|
||||
|
||||
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
|
||||
|
||||
if not chat_messages:
|
||||
return []
|
||||
|
||||
# prune the chat message if it exceeds the max token limit
|
||||
curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_buffer_length > self.max_token_limit and chat_messages:
|
||||
pruned_memory.append(chat_messages.pop(0))
|
||||
curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
|
||||
|
||||
return to_lc_messages(chat_messages)
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
buffer: Any = self.buffer
|
||||
if self.return_messages:
|
||||
final_buffer: Any = buffer
|
||||
else:
|
||||
final_buffer = get_buffer_string(
|
||||
buffer,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
return {self.memory_key: final_buffer}
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Nothing should be saved or changed"""
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear, got a memory like a vault."""
|
||||
pass
|
|
@ -1,36 +0,0 @@
|
|||
from typing import Any, List, Dict
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import get_buffer_string
|
||||
|
||||
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
|
||||
|
||||
class ReadOnlyConversationTokenDBStringBufferSharedMemory(BaseChatMemory):
|
||||
memory: ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Return memory variables."""
|
||||
return self.memory.memory_variables
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Load memory variables from memory."""
|
||||
buffer: Any = self.memory.buffer
|
||||
|
||||
final_buffer = get_buffer_string(
|
||||
buffer,
|
||||
human_prefix=self.memory.human_prefix,
|
||||
ai_prefix=self.memory.ai_prefix,
|
||||
)
|
||||
|
||||
return {self.memory.memory_key: final_buffer}
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Nothing should be saved or changed"""
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear, got a memory like a vault."""
|
||||
pass
|
109
api/core/memory/token_buffer_memory.py
Normal file
109
api/core/memory/token_buffer_memory.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, TextPromptMessageContent, UserPromptMessage, \
|
||||
AssistantPromptMessage, PromptMessageRole
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message
|
||||
|
||||
|
||||
class TokenBufferMemory:
|
||||
def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None:
|
||||
self.conversation = conversation
|
||||
self.model_instance = model_instance
|
||||
|
||||
def get_history_prompt_messages(self, max_token_limit: int = 2000,
|
||||
message_limit: int = 10) -> list[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages.
|
||||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
"""
|
||||
app_record = self.conversation.app
|
||||
|
||||
# fetch limited messages, and return reversed
|
||||
messages = db.session.query(Message).filter(
|
||||
Message.conversation_id == self.conversation.id,
|
||||
Message.answer != ''
|
||||
).order_by(Message.created_at.desc()).limit(message_limit).all()
|
||||
|
||||
messages = list(reversed(messages))
|
||||
message_file_parser = MessageFileParser(
|
||||
tenant_id=app_record.tenant_id,
|
||||
app_id=app_record.id
|
||||
)
|
||||
|
||||
prompt_messages = []
|
||||
for message in messages:
|
||||
files = message.message_files
|
||||
if files:
|
||||
file_objs = message_file_parser.transform_message_files(
|
||||
files, message.app_model_config
|
||||
)
|
||||
|
||||
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
|
||||
for file_obj in file_objs:
|
||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
|
||||
prompt_messages.append(AssistantPromptMessage(content=message.answer))
|
||||
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
# prune the chat message if it exceeds the max token limit
|
||||
provider_instance = model_provider_factory.get_provider_instance(self.model_instance.provider)
|
||||
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
|
||||
|
||||
curr_message_tokens = model_type_instance.get_num_tokens(
|
||||
self.model_instance.model,
|
||||
self.model_instance.credentials,
|
||||
prompt_messages
|
||||
)
|
||||
|
||||
if curr_message_tokens > max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_message_tokens > max_token_limit and prompt_messages:
|
||||
pruned_memory.append(prompt_messages.pop(0))
|
||||
curr_message_tokens = model_type_instance.get_num_tokens(
|
||||
self.model_instance.model,
|
||||
self.model_instance.credentials,
|
||||
prompt_messages
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def get_history_prompt_text(self, human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int = 10) -> str:
|
||||
"""
|
||||
Get history prompt text.
|
||||
:param human_prefix: human prefix
|
||||
:param ai_prefix: ai prefix
|
||||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
:return:
|
||||
"""
|
||||
prompt_messages = self.get_history_prompt_messages(
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=message_limit
|
||||
)
|
||||
|
||||
string_messages = []
|
||||
for m in prompt_messages:
|
||||
if m.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif m.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
209
api/core/model_manager.py
Normal file
209
api/core/model_manager.py
Normal file
|
@ -0,0 +1,209 @@
|
|||
from typing import Optional, Union, Generator, cast, List, IO
|
||||
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
|
||||
class ModelInstance:
|
||||
"""
|
||||
Model instance class
|
||||
"""
|
||||
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
|
||||
self._provider_model_bundle = provider_model_bundle
|
||||
self.model = model
|
||||
self.provider = provider_model_bundle.configuration.provider.provider
|
||||
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||
self.model_type_instance = self._provider_model_bundle.model_type_instance
|
||||
|
||||
def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
|
||||
"""
|
||||
Fetch credentials from provider model bundle
|
||||
:param provider_model_bundle: provider model bundle
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=provider_model_bundle.model_type_instance.model_type,
|
||||
model=model
|
||||
)
|
||||
|
||||
if credentials is None:
|
||||
raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
|
||||
|
||||
return credentials
|
||||
|
||||
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||
raise Exception(f"Model type instance is not LargeLanguageModel")
|
||||
|
||||
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
|
||||
return self.model_type_instance.invoke(
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception(f"Model type instance is not TextEmbeddingModel")
|
||||
|
||||
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
|
||||
return self.model_type_instance.invoke(
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
user=user
|
||||
)
|
||||
|
||||
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception(f"Model type instance is not RerankModel")
|
||||
|
||||
self.model_type_instance = cast(RerankModel, self.model_type_instance)
|
||||
return self.model_type_instance.invoke(
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
user=user
|
||||
)
|
||||
|
||||
def invoke_moderation(self, text: str, user: Optional[str] = None) \
|
||||
-> bool:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
:param text: text to moderate
|
||||
:param user: unique user id
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, ModerationModel):
|
||||
raise Exception(f"Model type instance is not ModerationModel")
|
||||
|
||||
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
|
||||
return self.model_type_instance.invoke(
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
user=user
|
||||
)
|
||||
|
||||
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, Speech2TextModel):
|
||||
raise Exception(f"Model type instance is not Speech2TextModel")
|
||||
|
||||
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
|
||||
return self.model_type_instance.invoke(
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
user=user
|
||||
)
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self) -> None:
|
||||
self._provider_manager = ProviderManager()
|
||||
|
||||
def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
|
||||
"""
|
||||
Get model instance
|
||||
:param tenant_id: tenant id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
provider_model_bundle = self._provider_manager.get_provider_model_bundle(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=model_type
|
||||
)
|
||||
|
||||
return ModelInstance(provider_model_bundle, model)
|
||||
|
||||
def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
|
||||
"""
|
||||
Get default model instance
|
||||
:param tenant_id: tenant id
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
default_model_entity = self._provider_manager.get_default_model(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type
|
||||
)
|
||||
|
||||
if not default_model_entity:
|
||||
raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
|
||||
|
||||
return self.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=default_model_entity.provider.provider,
|
||||
model_type=model_type,
|
||||
model=default_model_entity.model
|
||||
)
|
|
@ -1,335 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain.callbacks.base import Callbacks
|
||||
|
||||
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.moderation.base import BaseModeration
|
||||
from core.model_providers.models.reranking.base import BaseReranking
|
||||
from core.model_providers.models.speech2text.base import BaseSpeech2Text
|
||||
from extensions.ext_database import db
|
||||
from models.provider import TenantDefaultModel
|
||||
|
||||
|
||||
class ModelFactory:
|
||||
|
||||
@classmethod
|
||||
def get_text_generation_model_from_model_config(cls, tenant_id: str,
|
||||
model_config: dict,
|
||||
streaming: bool = False,
|
||||
callbacks: Callbacks = None) -> Optional[BaseLLM]:
|
||||
provider_name = model_config.get("provider")
|
||||
model_name = model_config.get("name")
|
||||
completion_params = model_config.get("completion_params", {})
|
||||
|
||||
return cls.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
model_kwargs=ModelKwargs(
|
||||
temperature=completion_params.get('temperature', 0),
|
||||
max_tokens=completion_params.get('max_tokens', 256),
|
||||
top_p=completion_params.get('top_p', 0),
|
||||
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
|
||||
presence_penalty=completion_params.get('presence_penalty', 0.1)
|
||||
),
|
||||
streaming=streaming,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_text_generation_model(cls,
|
||||
tenant_id: str,
|
||||
model_provider_name: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
model_kwargs: Optional[ModelKwargs] = None,
|
||||
streaming: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
deduct_quota: bool = True) -> Optional[BaseLLM]:
|
||||
"""
|
||||
get text generation model.
|
||||
|
||||
:param tenant_id: a string representing the ID of the tenant.
|
||||
:param model_provider_name:
|
||||
:param model_name:
|
||||
:param model_kwargs:
|
||||
:param streaming:
|
||||
:param callbacks:
|
||||
:param deduct_quota:
|
||||
:return:
|
||||
"""
|
||||
is_default_model = False
|
||||
if model_provider_name is None and model_name is None:
|
||||
default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)
|
||||
|
||||
if not default_model:
|
||||
raise LLMBadRequestError(f"Default model is not available. "
|
||||
f"Please configure a Default System Reasoning Model "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
model_provider_name = default_model.provider_name
|
||||
model_name = default_model.model_name
|
||||
is_default_model = True
|
||||
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||
|
||||
if not model_provider:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
|
||||
|
||||
# init text generation model
|
||||
model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)
|
||||
|
||||
try:
|
||||
model_instance = model_class(
|
||||
model_provider=model_provider,
|
||||
name=model_name,
|
||||
model_kwargs=model_kwargs,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks
|
||||
)
|
||||
except LLMBadRequestError as e:
|
||||
if is_default_model:
|
||||
raise LLMBadRequestError(f"Default model {model_name} is not available. "
|
||||
f"Please check your model provider credentials.")
|
||||
else:
|
||||
raise e
|
||||
|
||||
if is_default_model or not deduct_quota:
|
||||
model_instance.deduct_quota = False
|
||||
|
||||
return model_instance
|
||||
|
||||
@classmethod
|
||||
def get_embedding_model(cls,
|
||||
tenant_id: str,
|
||||
model_provider_name: Optional[str] = None,
|
||||
model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
|
||||
"""
|
||||
get embedding model.
|
||||
|
||||
:param tenant_id: a string representing the ID of the tenant.
|
||||
:param model_provider_name:
|
||||
:param model_name:
|
||||
:return:
|
||||
"""
|
||||
if model_provider_name is None and model_name is None:
|
||||
default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)
|
||||
|
||||
if not default_model:
|
||||
raise LLMBadRequestError(f"Default model is not available. "
|
||||
f"Please configure a Default Embedding Model "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
model_provider_name = default_model.provider_name
|
||||
model_name = default_model.model_name
|
||||
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||
|
||||
if not model_provider:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
|
||||
|
||||
# init embedding model
|
||||
model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
|
||||
return model_class(
|
||||
model_provider=model_provider,
|
||||
name=model_name
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_reranking_model(cls,
|
||||
tenant_id: str,
|
||||
model_provider_name: Optional[str] = None,
|
||||
model_name: Optional[str] = None) -> Optional[BaseReranking]:
|
||||
"""
|
||||
get reranking model.
|
||||
|
||||
:param tenant_id: a string representing the ID of the tenant.
|
||||
:param model_provider_name:
|
||||
:param model_name:
|
||||
:return:
|
||||
"""
|
||||
if (model_provider_name is None or len(model_provider_name) == 0) and (model_name is None or len(model_name) == 0):
|
||||
default_model = cls.get_default_model(tenant_id, ModelType.RERANKING)
|
||||
|
||||
if not default_model:
|
||||
raise LLMBadRequestError(f"Default model is not available. "
|
||||
f"Please configure a Default Reranking Model "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
model_provider_name = default_model.provider_name
|
||||
model_name = default_model.model_name
|
||||
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||
|
||||
if not model_provider:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
|
||||
|
||||
# init reranking model
|
||||
model_class = model_provider.get_model_class(model_type=ModelType.RERANKING)
|
||||
return model_class(
|
||||
model_provider=model_provider,
|
||||
name=model_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_speech2text_model(cls,
|
||||
tenant_id: str,
|
||||
model_provider_name: Optional[str] = None,
|
||||
model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
|
||||
"""
|
||||
get speech to text model.
|
||||
|
||||
:param tenant_id: a string representing the ID of the tenant.
|
||||
:param model_provider_name:
|
||||
:param model_name:
|
||||
:return:
|
||||
"""
|
||||
if model_provider_name is None and model_name is None:
|
||||
default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)
|
||||
|
||||
if not default_model:
|
||||
raise LLMBadRequestError(f"Default model is not available. "
|
||||
f"Please configure a Default Speech-to-Text Model "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
model_provider_name = default_model.provider_name
|
||||
model_name = default_model.model_name
|
||||
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||
|
||||
if not model_provider:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
|
||||
|
||||
# init speech to text model
|
||||
model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
|
||||
return model_class(
|
||||
model_provider=model_provider,
|
||||
name=model_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_moderation_model(cls,
|
||||
tenant_id: str,
|
||||
model_provider_name: str,
|
||||
model_name: str) -> Optional[BaseModeration]:
|
||||
"""
|
||||
get moderation model.
|
||||
|
||||
:param tenant_id: a string representing the ID of the tenant.
|
||||
:param model_provider_name:
|
||||
:param model_name:
|
||||
:return:
|
||||
"""
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||
|
||||
if not model_provider:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
|
||||
|
||||
# init moderation model
|
||||
model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
|
||||
return model_class(
|
||||
model_provider=model_provider,
|
||||
name=model_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
|
||||
"""
|
||||
get default model of model type.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
# get default model
|
||||
default_model = db.session.query(TenantDefaultModel) \
|
||||
.filter(
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.value
|
||||
).first()
|
||||
|
||||
if not default_model:
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rules()
|
||||
for model_provider_name, model_provider_rule in model_provider_rules.items():
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||
if not model_provider:
|
||||
continue
|
||||
|
||||
model_list = model_provider.get_supported_model_list(model_type)
|
||||
if model_list:
|
||||
model_info = model_list[0]
|
||||
default_model = TenantDefaultModel(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type.value,
|
||||
provider_name=model_provider_name,
|
||||
model_name=model_info['id']
|
||||
)
|
||||
db.session.add(default_model)
|
||||
db.session.commit()
|
||||
break
|
||||
|
||||
return default_model
|
||||
|
||||
@classmethod
|
||||
def update_default_model(cls,
|
||||
tenant_id: str,
|
||||
model_type: ModelType,
|
||||
provider_name: str,
|
||||
model_name: str) -> TenantDefaultModel:
|
||||
"""
|
||||
update default model of model type.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_type:
|
||||
:param provider_name:
|
||||
:param model_name:
|
||||
:return:
|
||||
"""
|
||||
model_provider_name = ModelProviderFactory.get_provider_names()
|
||||
if provider_name not in model_provider_name:
|
||||
raise ValueError(f'Invalid provider name: {provider_name}')
|
||||
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)
|
||||
|
||||
if not model_provider:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
|
||||
|
||||
model_list = model_provider.get_supported_model_list(model_type)
|
||||
model_ids = [model['id'] for model in model_list]
|
||||
if model_name not in model_ids:
|
||||
raise ValueError(f'Invalid model name: {model_name}')
|
||||
|
||||
# get default model
|
||||
default_model = db.session.query(TenantDefaultModel) \
|
||||
.filter(
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.value
|
||||
).first()
|
||||
|
||||
if default_model:
|
||||
# update default model
|
||||
default_model.provider_name = provider_name
|
||||
default_model.model_name = model_name
|
||||
db.session.commit()
|
||||
else:
|
||||
# create default model
|
||||
default_model = TenantDefaultModel(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type.value,
|
||||
provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
)
|
||||
db.session.add(default_model)
|
||||
db.session.commit()
|
||||
|
||||
return default_model
|
|
@ -1,276 +0,0 @@
|
|||
from typing import Type
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.rules import provider_rules
|
||||
from extensions.ext_database import db
|
||||
from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType
|
||||
|
||||
DEFAULT_MODELS = {
|
||||
ModelType.TEXT_GENERATION.value: {
|
||||
'provider_name': 'openai',
|
||||
'model_name': 'gpt-3.5-turbo',
|
||||
},
|
||||
ModelType.EMBEDDINGS.value: {
|
||||
'provider_name': 'openai',
|
||||
'model_name': 'text-embedding-ada-002',
|
||||
},
|
||||
ModelType.SPEECH_TO_TEXT.value: {
|
||||
'provider_name': 'openai',
|
||||
'model_name': 'whisper-1',
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ModelProviderFactory:
|
||||
@classmethod
|
||||
def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
|
||||
if provider_name == 'openai':
|
||||
from core.model_providers.providers.openai_provider import OpenAIProvider
|
||||
return OpenAIProvider
|
||||
elif provider_name == 'anthropic':
|
||||
from core.model_providers.providers.anthropic_provider import AnthropicProvider
|
||||
return AnthropicProvider
|
||||
elif provider_name == 'minimax':
|
||||
from core.model_providers.providers.minimax_provider import MinimaxProvider
|
||||
return MinimaxProvider
|
||||
elif provider_name == 'spark':
|
||||
from core.model_providers.providers.spark_provider import SparkProvider
|
||||
return SparkProvider
|
||||
elif provider_name == 'tongyi':
|
||||
from core.model_providers.providers.tongyi_provider import TongyiProvider
|
||||
return TongyiProvider
|
||||
elif provider_name == 'wenxin':
|
||||
from core.model_providers.providers.wenxin_provider import WenxinProvider
|
||||
return WenxinProvider
|
||||
elif provider_name == 'zhipuai':
|
||||
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
|
||||
return ZhipuAIProvider
|
||||
elif provider_name == 'chatglm':
|
||||
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
|
||||
return ChatGLMProvider
|
||||
elif provider_name == 'baichuan':
|
||||
from core.model_providers.providers.baichuan_provider import BaichuanProvider
|
||||
return BaichuanProvider
|
||||
elif provider_name == 'azure_openai':
|
||||
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
return AzureOpenAIProvider
|
||||
elif provider_name == 'replicate':
|
||||
from core.model_providers.providers.replicate_provider import ReplicateProvider
|
||||
return ReplicateProvider
|
||||
elif provider_name == 'huggingface_hub':
|
||||
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
|
||||
return HuggingfaceHubProvider
|
||||
elif provider_name == 'xinference':
|
||||
from core.model_providers.providers.xinference_provider import XinferenceProvider
|
||||
return XinferenceProvider
|
||||
elif provider_name == 'openllm':
|
||||
from core.model_providers.providers.openllm_provider import OpenLLMProvider
|
||||
return OpenLLMProvider
|
||||
elif provider_name == 'localai':
|
||||
from core.model_providers.providers.localai_provider import LocalAIProvider
|
||||
return LocalAIProvider
|
||||
elif provider_name == 'cohere':
|
||||
from core.model_providers.providers.cohere_provider import CohereProvider
|
||||
return CohereProvider
|
||||
elif provider_name == 'jina':
|
||||
from core.model_providers.providers.jina_provider import JinaProvider
|
||||
return JinaProvider
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_provider_names(cls):
|
||||
"""
|
||||
Returns a list of provider names.
|
||||
"""
|
||||
return list(provider_rules.keys())
|
||||
|
||||
@classmethod
|
||||
def get_provider_rules(cls):
|
||||
"""
|
||||
Returns a list of provider rules.
|
||||
|
||||
:return:
|
||||
"""
|
||||
return provider_rules
|
||||
|
||||
@classmethod
|
||||
def get_provider_rule(cls, provider_name: str):
|
||||
"""
|
||||
Returns provider rule.
|
||||
"""
|
||||
return provider_rules[provider_name]
|
||||
|
||||
@classmethod
|
||||
def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
|
||||
"""
|
||||
get preferred model provider.
|
||||
|
||||
:param tenant_id: a string representing the ID of the tenant.
|
||||
:param model_provider_name:
|
||||
:return:
|
||||
"""
|
||||
# get preferred provider
|
||||
preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
|
||||
if not preferred_provider or not preferred_provider.is_valid:
|
||||
return None
|
||||
|
||||
# init model provider
|
||||
model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
|
||||
return model_provider_class(provider=preferred_provider)
|
||||
|
||||
@classmethod
|
||||
def get_preferred_type_by_preferred_model_provider(cls,
|
||||
tenant_id: str,
|
||||
model_provider_name: str,
|
||||
preferred_model_provider: TenantPreferredModelProvider):
|
||||
"""
|
||||
get preferred provider type by preferred model provider.
|
||||
|
||||
:param model_provider_name:
|
||||
:param preferred_model_provider:
|
||||
:return:
|
||||
"""
|
||||
if not preferred_model_provider:
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
|
||||
support_provider_types = model_provider_rules['support_provider_types']
|
||||
|
||||
if ProviderType.CUSTOM.value in support_provider_types:
|
||||
custom_provider = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == model_provider_name,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.is_valid == True
|
||||
).first()
|
||||
|
||||
if custom_provider:
|
||||
return ProviderType.CUSTOM.value
|
||||
|
||||
model_provider = cls.get_model_provider_class(model_provider_name)
|
||||
|
||||
if ProviderType.SYSTEM.value in support_provider_types \
|
||||
and model_provider.is_provider_type_system_supported():
|
||||
return ProviderType.SYSTEM.value
|
||||
elif ProviderType.CUSTOM.value in support_provider_types:
|
||||
return ProviderType.CUSTOM.value
|
||||
else:
|
||||
return preferred_model_provider.preferred_provider_type
|
||||
|
||||
@classmethod
|
||||
def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
|
||||
"""
|
||||
get preferred provider of tenant.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_provider_name:
|
||||
:return:
|
||||
"""
|
||||
# get preferred provider type
|
||||
preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)
|
||||
|
||||
# get providers by preferred provider type
|
||||
providers = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == model_provider_name,
|
||||
Provider.provider_type == preferred_provider_type
|
||||
).all()
|
||||
|
||||
no_system_provider = False
|
||||
if preferred_provider_type == ProviderType.SYSTEM.value:
|
||||
quota_type_to_provider_dict = {}
|
||||
for provider in providers:
|
||||
quota_type_to_provider_dict[provider.quota_type] = provider
|
||||
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
|
||||
for quota_type_enum in ProviderQuotaType:
|
||||
quota_type = quota_type_enum.value
|
||||
if quota_type in model_provider_rules['system_config']['supported_quota_types']:
|
||||
if quota_type in quota_type_to_provider_dict.keys():
|
||||
provider = quota_type_to_provider_dict[quota_type]
|
||||
if provider.is_valid and provider.quota_limit > provider.quota_used:
|
||||
return provider
|
||||
elif quota_type == ProviderQuotaType.TRIAL.value:
|
||||
try:
|
||||
provider = Provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=model_provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
is_valid=True,
|
||||
quota_type=ProviderQuotaType.TRIAL.value,
|
||||
quota_limit=model_provider_rules['system_config']['quota_limit'],
|
||||
quota_used=0
|
||||
)
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
provider = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == model_provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL.value
|
||||
).first()
|
||||
|
||||
if provider.quota_limit == 0:
|
||||
return None
|
||||
|
||||
return provider
|
||||
|
||||
no_system_provider = True
|
||||
|
||||
if no_system_provider:
|
||||
providers = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == model_provider_name,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).all()
|
||||
|
||||
if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
|
||||
if providers:
|
||||
return providers[0]
|
||||
else:
|
||||
try:
|
||||
provider = Provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=model_provider_name,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
is_valid=False
|
||||
)
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
provider = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == model_provider_name,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).first()
|
||||
|
||||
return provider
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
|
||||
"""
|
||||
get preferred provider type of tenant.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_provider_name:
|
||||
:return:
|
||||
"""
|
||||
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
|
||||
.filter(
|
||||
TenantPreferredModelProvider.tenant_id == tenant_id,
|
||||
TenantPreferredModelProvider.provider_name == model_provider_name
|
||||
).first()
|
||||
|
||||
return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user