mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
feat: server multi models support (#799)
This commit is contained in:
parent
d8b712b325
commit
5fa2161b05
|
@ -19,7 +19,8 @@ def check_file_for_chinese_comments(file_path):
|
|||
|
||||
def main():
|
||||
has_chinese = False
|
||||
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py']
|
||||
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
|
||||
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py']
|
||||
|
||||
for root, _, files in os.walk("."):
|
||||
for file in files:
|
||||
|
|
|
@ -102,3 +102,29 @@ NOTION_INTEGRATION_TYPE=public
|
|||
NOTION_CLIENT_SECRET=you-client-secret
|
||||
NOTION_CLIENT_ID=you-client-id
|
||||
NOTION_INTERNAL_SECRET=you-internal-secret
|
||||
|
||||
# Hosted Model Credentials
|
||||
HOSTED_OPENAI_ENABLED=false
|
||||
HOSTED_OPENAI_API_KEY=
|
||||
HOSTED_OPENAI_API_BASE=
|
||||
HOSTED_OPENAI_API_ORGANIZATION=
|
||||
HOSTED_OPENAI_QUOTA_LIMIT=200
|
||||
HOSTED_OPENAI_PAID_ENABLED=false
|
||||
HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
|
||||
HOSTED_OPENAI_PAID_INCREASE_QUOTA=1
|
||||
|
||||
HOSTED_AZURE_OPENAI_ENABLED=false
|
||||
HOSTED_AZURE_OPENAI_API_KEY=
|
||||
HOSTED_AZURE_OPENAI_API_BASE=
|
||||
HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
|
||||
|
||||
HOSTED_ANTHROPIC_ENABLED=false
|
||||
HOSTED_ANTHROPIC_API_BASE=
|
||||
HOSTED_ANTHROPIC_API_KEY=
|
||||
HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
|
||||
HOSTED_ANTHROPIC_PAID_ENABLED=false
|
||||
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
|
||||
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
|
||||
|
||||
STRIPE_API_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
19
api/app.py
19
api/app.py
|
@ -16,8 +16,9 @@ from flask import Flask, request, Response, session
|
|||
import flask_login
|
||||
from flask_cors import CORS
|
||||
|
||||
from core.model_providers.providers import hosted
|
||||
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
||||
ext_database, ext_storage, ext_mail
|
||||
ext_database, ext_storage, ext_mail, ext_stripe
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
|
||||
|
@ -71,7 +72,7 @@ def create_app(test_config=None) -> Flask:
|
|||
register_blueprints(app)
|
||||
register_commands(app)
|
||||
|
||||
core.init_app(app)
|
||||
hosted.init_app(app)
|
||||
|
||||
return app
|
||||
|
||||
|
@ -88,6 +89,7 @@ def initialize_extensions(app):
|
|||
ext_login.init_app(app)
|
||||
ext_mail.init_app(app)
|
||||
ext_sentry.init_app(app)
|
||||
ext_stripe.init_app(app)
|
||||
|
||||
|
||||
def _create_tenant_for_account(account):
|
||||
|
@ -246,5 +248,18 @@ def threads():
|
|||
}
|
||||
|
||||
|
||||
@app.route('/db-pool-stat')
|
||||
def pool_stat():
|
||||
engine = db.engine
|
||||
return {
|
||||
'pool_size': engine.pool.size(),
|
||||
'checked_in_connections': engine.pool.checkedin(),
|
||||
'checked_out_connections': engine.pool.checkedout(),
|
||||
'overflow_connections': engine.pool.overflow(),
|
||||
'connection_timeout': engine.pool.timeout(),
|
||||
'recycle_time': db.engine.pool._recycle
|
||||
}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(host='0.0.0.0', port=5001)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import datetime
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
|
@ -9,18 +9,18 @@ from flask import current_app
|
|||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
from core.model_providers.providers.hosted import hosted_model_providers
|
||||
from libs.password import password_pattern, valid_password, hash_password
|
||||
from libs.helper import email as email_validate
|
||||
from extensions.ext_database import db
|
||||
from libs.rsa import generate_key_pair
|
||||
from models.account import InvitationCode, Tenant
|
||||
from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
|
||||
from models.dataset import Dataset, DatasetQuery, Document
|
||||
from models.model import Account
|
||||
import secrets
|
||||
import base64
|
||||
|
||||
from models.provider import Provider, ProviderName
|
||||
from services.provider_service import ProviderService
|
||||
from models.provider import Provider, ProviderType, ProviderQuotaType
|
||||
|
||||
|
||||
@click.command('reset-password', help='Reset the account password.')
|
||||
|
@ -251,26 +251,37 @@ def clean_unused_dataset_indexes():
|
|||
|
||||
@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
|
||||
def sync_anthropic_hosted_providers():
|
||||
if not hosted_model_providers.anthropic:
|
||||
click.echo(click.style('Anthropic hosted provider is not configured.', fg='red'))
|
||||
return
|
||||
|
||||
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
|
||||
count = 0
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
|
||||
providers = db.session.query(Provider).filter(
|
||||
Provider.provider_name == 'anthropic',
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL.value,
|
||||
).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for tenant in tenants:
|
||||
for provider in providers:
|
||||
try:
|
||||
click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
ProviderName.ANTHROPIC.value,
|
||||
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
|
||||
True
|
||||
)
|
||||
click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
|
||||
original_quota_limit = provider.quota_limit
|
||||
new_quota_limit = hosted_model_providers.anthropic.quota_limit
|
||||
division = math.ceil(new_quota_limit / 1000)
|
||||
|
||||
provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
|
||||
else original_quota_limit * division
|
||||
provider.quota_used = division * provider.quota_used
|
||||
db.session.commit()
|
||||
|
||||
count += 1
|
||||
except Exception as e:
|
||||
click.echo(click.style(
|
||||
|
|
|
@ -41,6 +41,7 @@ DEFAULTS = {
|
|||
'SESSION_USE_SIGNER': 'True',
|
||||
'DEPLOY_ENV': 'PRODUCTION',
|
||||
'SQLALCHEMY_POOL_SIZE': 30,
|
||||
'SQLALCHEMY_POOL_RECYCLE': 3600,
|
||||
'SQLALCHEMY_ECHO': 'False',
|
||||
'SENTRY_TRACES_SAMPLE_RATE': 1.0,
|
||||
'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
|
||||
|
@ -50,9 +51,16 @@ DEFAULTS = {
|
|||
'PDF_PREVIEW': 'True',
|
||||
'LOG_LEVEL': 'INFO',
|
||||
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
|
||||
'DEFAULT_LLM_PROVIDER': 'openai',
|
||||
'OPENAI_HOSTED_QUOTA_LIMIT': 200,
|
||||
'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
|
||||
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
|
||||
'HOSTED_OPENAI_ENABLED': 'False',
|
||||
'HOSTED_OPENAI_PAID_ENABLED': 'False',
|
||||
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
|
||||
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
|
||||
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
|
||||
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
|
||||
'HOSTED_ANTHROPIC_ENABLED': 'False',
|
||||
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
|
||||
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
|
||||
'TENANT_DOCUMENT_COUNT': 100,
|
||||
'CLEAN_DAY_SETTING': 30
|
||||
}
|
||||
|
@ -182,7 +190,10 @@ class Config:
|
|||
}
|
||||
|
||||
self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}"
|
||||
self.SQLALCHEMY_ENGINE_OPTIONS = {'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE'))}
|
||||
self.SQLALCHEMY_ENGINE_OPTIONS = {
|
||||
'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
|
||||
'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
|
||||
}
|
||||
|
||||
self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
|
||||
|
||||
|
@ -194,20 +205,35 @@ class Config:
|
|||
self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
|
||||
|
||||
# hosted provider credentials
|
||||
self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
|
||||
self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
|
||||
self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
|
||||
self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
|
||||
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
|
||||
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
|
||||
self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
|
||||
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
|
||||
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
|
||||
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
|
||||
|
||||
self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
|
||||
self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
|
||||
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
|
||||
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
|
||||
self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
|
||||
self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')
|
||||
|
||||
self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
|
||||
self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
|
||||
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
|
||||
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
|
||||
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
|
||||
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
|
||||
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')
|
||||
|
||||
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
|
||||
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
|
||||
|
||||
# By default it is False
|
||||
# You could disable it for compatibility with certain OpenAPI providers
|
||||
self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
|
||||
|
||||
# For temp use only
|
||||
# set default LLM provider, default is 'openai', support `azure_openai`
|
||||
self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
|
||||
|
||||
# notion import setting
|
||||
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
|
||||
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
|
||||
|
|
|
@ -18,10 +18,13 @@ from .auth import login, oauth, data_source_oauth, activate
|
|||
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
|
||||
|
||||
# Import workspace controllers
|
||||
from .workspace import workspace, members, model_providers, account, tool_providers
|
||||
from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
|
||||
|
||||
# Import universal chat controllers
|
||||
from .universal_chat import chat, conversation, message, parameter, audio
|
||||
|
||||
# Import webhook controllers
|
||||
from .webhook import stripe
|
||||
|
|
|
@ -2,16 +2,17 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import flask
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
|
||||
from werkzeug.exceptions import Unauthorized, Forbidden
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from constants.model_template import model_templates, demo_model_templates
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from events.app_event import app_was_created, app_was_deleted
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
|
@ -126,9 +127,9 @@ class AppListApi(Resource):
|
|||
if args['model_config'] is not None:
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account=current_user,
|
||||
config=args['model_config'],
|
||||
mode=args['mode']
|
||||
config=args['model_config']
|
||||
)
|
||||
|
||||
app = App(
|
||||
|
@ -164,6 +165,21 @@ class AppListApi(Resource):
|
|||
app = App(**model_config_template['app'])
|
||||
app_model_config = AppModelConfig(**model_config_template['model_config'])
|
||||
|
||||
default_model = ModelFactory.get_default_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.TEXT_GENERATION
|
||||
)
|
||||
|
||||
if default_model:
|
||||
model_dict = app_model_config.model_dict
|
||||
model_dict['provider'] = default_model.provider_name
|
||||
model_dict['name'] = default_model.model_name
|
||||
app_model_config.model = json.dumps(model_dict)
|
||||
else:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Text Generation Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
app.name = args['name']
|
||||
app.mode = args['mode']
|
||||
app.icon = args['icon']
|
||||
|
|
|
@ -14,7 +14,7 @@ from controllers.console.app.error import AppUnavailableError, \
|
|||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from flask_restful import Resource
|
||||
from services.audio_service import AudioService
|
||||
|
|
|
@ -17,7 +17,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
|
|||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.helper import uuid_value
|
||||
from flask_restful import Resource, reqparse
|
||||
|
@ -41,8 +41,11 @@ class CompletionMessageApi(Resource):
|
|||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, location='json')
|
||||
parser.add_argument('model_config', type=dict, required=True, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] != 'blocking'
|
||||
|
||||
account = flask_login.current_user
|
||||
|
||||
try:
|
||||
|
@ -51,7 +54,7 @@ class CompletionMessageApi(Resource):
|
|||
user=account,
|
||||
args=args,
|
||||
from_source='console',
|
||||
streaming=True,
|
||||
streaming=streaming,
|
||||
is_model_config_override=True
|
||||
)
|
||||
|
||||
|
@ -111,8 +114,11 @@ class ChatMessageApi(Resource):
|
|||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('model_config', type=dict, required=True, location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] != 'blocking'
|
||||
|
||||
account = flask_login.current_user
|
||||
|
||||
try:
|
||||
|
@ -121,7 +127,7 @@ class ChatMessageApi(Resource):
|
|||
user=account,
|
||||
args=args,
|
||||
from_source='console',
|
||||
streaming=True,
|
||||
streaming=streaming,
|
||||
is_model_config_override=True
|
||||
)
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
|
|||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
|
||||
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
|
||||
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni
|
|||
AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
|
|
|
@ -28,9 +28,9 @@ class ModelConfigResource(Resource):
|
|||
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account=current_user,
|
||||
config=request.json,
|
||||
mode=app_model.mode
|
||||
config=request.json
|
||||
)
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
|
|
|
@ -255,7 +255,7 @@ class DataSourceNotionApi(Resource):
|
|||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
|
||||
return response, 200
|
||||
|
||||
|
||||
|
|
|
@ -5,10 +5,13 @@ from flask_restful import Resource, reqparse, fields, marshal, marshal_with
|
|||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment, Document
|
||||
|
@ -97,6 +100,15 @@ class DatasetListApi(Resource):
|
|||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
try:
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
|
@ -235,12 +247,26 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
raise NotFound("File not found.")
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form'])
|
||||
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
|
||||
args['process_rule'], args['doc_form'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
elif args['info_list']['data_source_type'] == 'notion_import':
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
|
||||
args['process_rule'], args['doc_form'])
|
||||
|
||||
try:
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
|
||||
args['info_list']['notion_info_list'],
|
||||
args['process_rule'], args['doc_form'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
return response, 200
|
||||
|
|
|
@ -18,7 +18,9 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
|
|||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
|
@ -280,6 +282,15 @@ class DatasetDocumentListApi(Resource):
|
|||
# validate args
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
|
@ -319,6 +330,15 @@ class DatasetInitApi(Resource):
|
|||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
|
@ -384,7 +404,13 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||
|
||||
indexing_runner = IndexingRunner()
|
||||
|
||||
response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict)
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
|
||||
data_process_rule_dict)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
return response
|
||||
|
||||
|
@ -445,12 +471,24 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
raise NotFound("File not found.")
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.file_indexing_estimate(file_details, data_process_rule_dict)
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
|
||||
data_process_rule_dict)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
elif dataset.data_source_type:
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.notion_indexing_estimate(info_list,
|
||||
data_process_rule_dict)
|
||||
try:
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
|
||||
info_list,
|
||||
data_process_rule_dict)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
return response
|
||||
|
|
|
@ -11,7 +11,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
|
|||
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.helper import TimestampField
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
@ -102,6 +102,8 @@ class HitTestingApi(Resource):
|
|||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except ValueError as e:
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
logging.exception("Hit testing failed.")
|
||||
raise InternalServerError(str(e))
|
||||
|
|
|
@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
|
|||
NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
|
|
|
@ -15,7 +15,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
|
|||
from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
|
|
|
@ -15,7 +15,7 @@ from controllers.console.app.error import AppMoreLikeThisDisabledError, Provider
|
|||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
|
||||
from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from services.completion_service import CompletionService
|
||||
|
|
|
@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
|
|||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from models.provider import ProviderName
|
||||
from models.model import InstalledApp
|
||||
|
||||
|
||||
|
@ -35,13 +33,12 @@ class AppParameterApi(InstalledAppResource):
|
|||
"""Retrieve app parameters."""
|
||||
app_model = installed_app.app
|
||||
app_model_config = app_model.app_model_config
|
||||
provider_name = LLMBuilder.get_default_provider(installed_app.tenant_id, 'whisper-1')
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
|
|||
NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
|
|
|
@ -12,9 +12,8 @@ from controllers.console import api
|
|||
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from core.constant import llm_constant
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
|
@ -27,6 +26,7 @@ class UniversalChatApi(UniversalChatResource):
|
|||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('provider', type=str, required=True, location='json')
|
||||
parser.add_argument('model', type=str, required=True, location='json')
|
||||
parser.add_argument('tools', type=list, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
@ -36,11 +36,7 @@ class UniversalChatApi(UniversalChatResource):
|
|||
# update app model config
|
||||
args['model_config'] = app_model_config.to_dict()
|
||||
args['model_config']['model']['name'] = args['model']
|
||||
|
||||
if not llm_constant.models[args['model']]:
|
||||
raise ValueError("Model not exists.")
|
||||
|
||||
args['model_config']['model']['provider'] = llm_constant.models[args['model']]
|
||||
args['model_config']['model']['provider'] = args['provider']
|
||||
args['model_config']['agent_mode']['tools'] = args['tools']
|
||||
|
||||
if not args['model_config']['agent_mode']['tools']:
|
||||
|
|
|
@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, \
|
|||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
|
|
@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
|
|||
from controllers.console import api
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from models.provider import ProviderName
|
||||
from models.model import App
|
||||
|
||||
|
||||
|
@ -23,13 +21,12 @@ class UniversalChatParameterApi(UniversalChatResource):
|
|||
"""Retrieve app parameters."""
|
||||
app_model = universal_app
|
||||
app_model_config = app_model.app_model_config
|
||||
provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1')
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
}
|
||||
|
||||
|
||||
|
|
53
api/controllers/console/webhook/stripe.py
Normal file
53
api/controllers/console/webhook/stripe.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
import logging
|
||||
|
||||
import stripe
|
||||
from flask import request, current_app
|
||||
from flask_restful import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from services.provider_checkout_service import ProviderCheckoutService
|
||||
|
||||
|
||||
class StripeWebhookApi(Resource):
|
||||
@setup_required
|
||||
@only_edition_cloud
|
||||
def post(self):
|
||||
payload = request.data
|
||||
sig_header = request.headers.get('STRIPE_SIGNATURE')
|
||||
webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, webhook_secret
|
||||
)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
return 'Invalid payload', 400
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
return 'Invalid signature', 400
|
||||
|
||||
# Handle the checkout.session.completed event
|
||||
if event['type'] == 'checkout.session.completed':
|
||||
logging.debug(event['data']['object']['id'])
|
||||
logging.debug(event['data']['object']['amount_subtotal'])
|
||||
logging.debug(event['data']['object']['currency'])
|
||||
logging.debug(event['data']['object']['payment_intent'])
|
||||
logging.debug(event['data']['object']['payment_status'])
|
||||
logging.debug(event['data']['object']['metadata'])
|
||||
|
||||
# Fulfill the purchase...
|
||||
provider_checkout_service = ProviderCheckoutService()
|
||||
|
||||
try:
|
||||
provider_checkout_service.fulfill_provider_order(event)
|
||||
except Exception as e:
|
||||
logging.debug(str(e))
|
||||
return 'success', 200
|
||||
|
||||
return 'success', 200
|
||||
|
||||
|
||||
api.add_resource(StripeWebhookApi, '/webhook/stripe')
|
|
@ -1,24 +1,18 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, abort
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.llm.provider.errors import ValidateFailedError
|
||||
from extensions.ext_database import db
|
||||
from libs import rsa
|
||||
from models.provider import Provider, ProviderType, ProviderName
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from services.provider_checkout_service import ProviderCheckoutService
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
class ProviderListApi(Resource):
|
||||
class ModelProviderListApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
@ -26,156 +20,36 @@ class ProviderListApi(Resource):
|
|||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
"""
|
||||
If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
|
||||
azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
|
||||
rest is replaced by * and the last two bits are displayed in plaintext
|
||||
|
||||
If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
|
||||
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
|
||||
"""
|
||||
|
||||
ProviderService.init_supported_provider(current_user.current_tenant)
|
||||
providers = Provider.query.filter_by(tenant_id=tenant_id).all()
|
||||
|
||||
provider_list = [
|
||||
{
|
||||
'provider_name': p.provider_name,
|
||||
'provider_type': p.provider_type,
|
||||
'is_valid': p.is_valid,
|
||||
'last_used': p.last_used,
|
||||
'is_enabled': p.is_enabled,
|
||||
**({
|
||||
'quota_type': p.quota_type,
|
||||
'quota_limit': p.quota_limit,
|
||||
'quota_used': p.quota_used
|
||||
} if p.provider_type == ProviderType.SYSTEM.value else {}),
|
||||
'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
|
||||
ProviderName(p.provider_name), only_custom=True)
|
||||
if p.provider_type == ProviderType.CUSTOM.value else None
|
||||
}
|
||||
for p in providers
|
||||
]
|
||||
provider_service = ProviderService()
|
||||
provider_list = provider_service.get_provider_list(tenant_id)
|
||||
|
||||
return provider_list
|
||||
|
||||
|
||||
class ProviderTokenApi(Resource):
|
||||
class ModelProviderValidateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
if provider not in [p.value for p in ProviderName]:
|
||||
abort(404)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
logging.log(logging.ERROR,
|
||||
f'User {current_user.id} is not authorized to update provider token, current_role is {current_user.current_tenant.current_role}')
|
||||
raise Forbidden()
|
||||
def post(self, provider_name: str):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument('token', type=ProviderService.get_token_type(
|
||||
tenant=current_user.current_tenant,
|
||||
provider_name=ProviderName(provider)
|
||||
), required=True, nullable=False, location='json')
|
||||
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args['token']:
|
||||
try:
|
||||
ProviderService.validate_provider_configs(
|
||||
tenant=current_user.current_tenant,
|
||||
provider_name=ProviderName(provider),
|
||||
configs=args['token']
|
||||
)
|
||||
token_is_valid = True
|
||||
except ValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
base64_encrypted_token = ProviderService.get_encrypted_token(
|
||||
tenant=current_user.current_tenant,
|
||||
provider_name=ProviderName(provider),
|
||||
configs=args['token']
|
||||
)
|
||||
else:
|
||||
base64_encrypted_token = None
|
||||
token_is_valid = False
|
||||
|
||||
tenant = current_user.current_tenant
|
||||
|
||||
provider_model = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant.id,
|
||||
Provider.provider_name == provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).first()
|
||||
|
||||
# Only allow updating token for CUSTOM provider type
|
||||
if provider_model:
|
||||
provider_model.encrypted_config = base64_encrypted_token
|
||||
provider_model.is_valid = token_is_valid
|
||||
else:
|
||||
provider_model = Provider(tenant_id=tenant.id, provider_name=provider,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=base64_encrypted_token,
|
||||
is_valid=token_is_valid)
|
||||
db.session.add(provider_model)
|
||||
|
||||
if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
|
||||
other_providers = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant.id,
|
||||
Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
|
||||
Provider.provider_name != provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).all()
|
||||
|
||||
for other_provider in other_providers:
|
||||
other_provider.is_valid = False
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
|
||||
ProviderName.HUGGINGFACEHUB.value]:
|
||||
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
|
||||
|
||||
return {'result': 'success'}, 201
|
||||
|
||||
|
||||
class ProviderTokenValidateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
if provider not in [p.value for p in ProviderName]:
|
||||
abort(404)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', type=ProviderService.get_token_type(
|
||||
tenant=current_user.current_tenant,
|
||||
provider_name=ProviderName(provider)
|
||||
), required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
# todo: remove this when the provider is supported
|
||||
if provider in [ProviderName.COHERE.value,
|
||||
ProviderName.HUGGINGFACEHUB.value]:
|
||||
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
|
||||
provider_service = ProviderService()
|
||||
|
||||
result = True
|
||||
error = None
|
||||
|
||||
try:
|
||||
ProviderService.validate_provider_configs(
|
||||
tenant=current_user.current_tenant,
|
||||
provider_name=ProviderName(provider),
|
||||
configs=args['token']
|
||||
provider_service.custom_provider_config_validate(
|
||||
provider_name=provider_name,
|
||||
config=args['config']
|
||||
)
|
||||
except ValidateFailedError as e:
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
error = str(e)
|
||||
error = str(ex)
|
||||
|
||||
response = {'result': 'success' if result else 'error'}
|
||||
|
||||
|
@ -185,91 +59,227 @@ class ProviderTokenValidateApi(Resource):
|
|||
return response
|
||||
|
||||
|
||||
class ProviderSystemApi(Resource):
|
||||
class ModelProviderUpdateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, provider):
|
||||
if provider not in [p.value for p in ProviderName]:
|
||||
abort(404)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('is_enabled', type=bool, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant = current_user.current_tenant_id
|
||||
|
||||
provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider).first()
|
||||
|
||||
if provider_model and provider_model.provider_type == ProviderType.SYSTEM.value:
|
||||
provider_model.is_valid = args['is_enabled']
|
||||
db.session.commit()
|
||||
elif not provider_model:
|
||||
if provider == ProviderName.OPENAI.value:
|
||||
quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
|
||||
elif provider == ProviderName.ANTHROPIC.value:
|
||||
quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
|
||||
else:
|
||||
quota_limit = 0
|
||||
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
provider,
|
||||
quota_limit,
|
||||
args['is_enabled']
|
||||
)
|
||||
else:
|
||||
abort(403)
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
if provider not in [p.value for p in ProviderName]:
|
||||
abort(404)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
def post(self, provider_name: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
provider_model = db.session.query(Provider).filter(Provider.tenant_id == current_user.current_tenant_id,
|
||||
Provider.provider_name == provider,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value).first()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
system_model = None
|
||||
if provider_model:
|
||||
system_model = {
|
||||
'result': 'success',
|
||||
'provider': {
|
||||
'provider_name': provider_model.provider_name,
|
||||
'provider_type': provider_model.provider_type,
|
||||
'is_valid': provider_model.is_valid,
|
||||
'last_used': provider_model.last_used,
|
||||
'is_enabled': provider_model.is_enabled,
|
||||
'quota_type': provider_model.quota_type,
|
||||
'quota_limit': provider_model.quota_limit,
|
||||
'quota_used': provider_model.quota_used
|
||||
}
|
||||
provider_service = ProviderService()
|
||||
|
||||
try:
|
||||
provider_service.save_custom_provider_config(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
config=args['config']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {'result': 'success'}, 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_name: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_service.delete_custom_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name
|
||||
)
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class ModelProviderModelValidateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
result = True
|
||||
error = None
|
||||
|
||||
try:
|
||||
provider_service.custom_provider_model_config_validate(
|
||||
provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type=args['model_type'],
|
||||
config=args['config']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
error = str(ex)
|
||||
|
||||
response = {'result': 'success' if result else 'error'}
|
||||
|
||||
if not result:
|
||||
response['error'] = error
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class ModelProviderModelUpdateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
try:
|
||||
provider_service.add_or_save_custom_provider_model_config(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type=args['model_type'],
|
||||
config=args['config']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_name: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text'], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_service.delete_custom_provider_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type=args['model_type']
|
||||
)
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class PreferredProviderTypeUpdateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
|
||||
choices=['system', 'custom'], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_service.switch_preferred_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
preferred_provider_type=args['preferred_provider_type']
|
||||
)
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class ModelProviderModelParameterRuleApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
try:
|
||||
parameter_rules = provider_service.get_model_parameter_rules(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type='text-generation'
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"Current Text Generation Model is invalid. Please switch to the available model.")
|
||||
|
||||
rules = {
|
||||
k: {
|
||||
'enabled': v.enabled,
|
||||
'min': v.min,
|
||||
'max': v.max,
|
||||
'default': v.default
|
||||
}
|
||||
else:
|
||||
abort(404)
|
||||
for k, v in vars(parameter_rules).items()
|
||||
}
|
||||
|
||||
return system_model
|
||||
return rules
|
||||
|
||||
|
||||
api.add_resource(ProviderTokenApi, '/providers/<provider>/token',
|
||||
endpoint='current_providers_token') # Deprecated
|
||||
api.add_resource(ProviderTokenValidateApi, '/providers/<provider>/token-validate',
|
||||
endpoint='current_providers_token_validate') # Deprecated
|
||||
class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
provider_service = ProviderCheckoutService()
|
||||
provider_checkout = provider_service.create_checkout(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
account=current_user
|
||||
)
|
||||
|
||||
api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
|
||||
endpoint='workspaces_current_providers_token') # PUT for updating provider token
|
||||
api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
|
||||
endpoint='workspaces_current_providers_token_validate') # POST for validating provider token
|
||||
return {
|
||||
'url': provider_checkout.get_checkout_url()
|
||||
}
|
||||
|
||||
api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list
|
||||
api.add_resource(ProviderSystemApi, '/workspaces/current/providers/<provider>/system',
|
||||
endpoint='workspaces_current_providers_system') # GET for getting provider quota, PUT for updating provider status
|
||||
|
||||
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
|
||||
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
|
||||
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
|
||||
api.add_resource(ModelProviderModelValidateApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/models/validate')
|
||||
api.add_resource(ModelProviderModelUpdateApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/models')
|
||||
api.add_resource(PreferredProviderTypeUpdateApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type')
|
||||
api.add_resource(ModelProviderModelParameterRuleApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
|
||||
api.add_resource(ModelProviderPaymentCheckoutUrlApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/checkout-url')
|
||||
|
|
108
api/controllers/console/workspace/models.py
Normal file
108
api/controllers/console/workspace/models.py
Normal file
|
@ -0,0 +1,108 @@
|
|||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from models.provider import ProviderType
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
class DefaultModelApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text'], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
provider_service = ProviderService()
|
||||
default_model = provider_service.get_default_model_of_model_type(
|
||||
tenant_id=tenant_id,
|
||||
model_type=args['model_type']
|
||||
)
|
||||
|
||||
if not default_model:
|
||||
return None
|
||||
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(
|
||||
tenant_id,
|
||||
default_model.provider_name
|
||||
)
|
||||
|
||||
if not model_provider:
|
||||
return {
|
||||
'model_name': default_model.model_name,
|
||||
'model_type': default_model.model_type,
|
||||
'model_provider': {
|
||||
'provider_name': default_model.provider_name
|
||||
}
|
||||
}
|
||||
|
||||
provider = model_provider.provider
|
||||
rst = {
|
||||
'model_name': default_model.model_name,
|
||||
'model_type': default_model.model_type,
|
||||
'model_provider': {
|
||||
'provider_name': provider.provider_name,
|
||||
'provider_type': provider.provider_type
|
||||
}
|
||||
}
|
||||
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
rst['model_provider']['quota_type'] = provider.quota_type
|
||||
rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
|
||||
rst['model_provider']['quota_limit'] = provider.quota_limit
|
||||
rst['model_provider']['quota_used'] = provider.quota_used
|
||||
|
||||
return rst
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
|
||||
parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_service.update_default_model_of_model_type(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=args['model_type'],
|
||||
provider_name=args['provider_name'],
|
||||
model_name=args['model_name']
|
||||
)
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class ValidModelApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, model_type):
|
||||
ModelType.value_of(model_type)
|
||||
|
||||
provider_service = ProviderService()
|
||||
valid_models = provider_service.get_valid_model_list(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=model_type
|
||||
)
|
||||
|
||||
return valid_models
|
||||
|
||||
|
||||
api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
|
||||
api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')
|
130
api/controllers/console/workspace/providers.py
Normal file
130
api/controllers/console/workspace/providers.py
Normal file
|
@ -0,0 +1,130 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from models.provider import ProviderType
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
class ProviderListApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
"""
|
||||
If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
|
||||
azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
|
||||
rest is replaced by * and the last two bits are displayed in plaintext
|
||||
|
||||
If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
|
||||
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
|
||||
"""
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_info_list = provider_service.get_provider_list(tenant_id)
|
||||
|
||||
provider_list = [
|
||||
{
|
||||
'provider_name': p['provider_name'],
|
||||
'provider_type': p['provider_type'],
|
||||
'is_valid': p['is_valid'],
|
||||
'last_used': p['last_used'],
|
||||
'is_enabled': p['is_valid'],
|
||||
**({
|
||||
'quota_type': p['quota_type'],
|
||||
'quota_limit': p['quota_limit'],
|
||||
'quota_used': p['quota_used']
|
||||
} if p['provider_type'] == ProviderType.SYSTEM.value else {}),
|
||||
'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
|
||||
if p['config'] else None
|
||||
}
|
||||
for name, provider_info in provider_info_list.items()
|
||||
for p in provider_info['providers']
|
||||
]
|
||||
|
||||
return provider_list
|
||||
|
||||
|
||||
class ProviderTokenApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if provider == 'openai':
|
||||
args['token'] = {
|
||||
'openai_api_key': args['token']
|
||||
}
|
||||
|
||||
provider_service = ProviderService()
|
||||
try:
|
||||
provider_service.save_custom_provider_config(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider,
|
||||
config=args['token']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {'result': 'success'}, 201
|
||||
|
||||
|
||||
class ProviderTokenValidateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
if provider == 'openai':
|
||||
args['token'] = {
|
||||
'openai_api_key': args['token']
|
||||
}
|
||||
|
||||
result = True
|
||||
error = None
|
||||
|
||||
try:
|
||||
provider_service.custom_provider_config_validate(
|
||||
provider_name=provider,
|
||||
config=args['token']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
error = str(ex)
|
||||
|
||||
response = {'result': 'success' if result else 'error'}
|
||||
|
||||
if not result:
|
||||
response['error'] = error
|
||||
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
|
||||
endpoint='workspaces_current_providers_token') # PUT for updating provider token
|
||||
api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
|
||||
endpoint='workspaces_current_providers_token_validate') # POST for validating provider token
|
||||
|
||||
api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list
|
|
@ -30,7 +30,7 @@ tenant_fields = {
|
|||
'created_at': TimestampField,
|
||||
'role': fields.String,
|
||||
'providers': fields.List(fields.Nested(provider_fields)),
|
||||
'in_trail': fields.Boolean,
|
||||
'in_trial': fields.Boolean,
|
||||
'trial_end_reason': fields.String,
|
||||
}
|
||||
|
||||
|
|
|
@ -4,8 +4,6 @@ from flask_restful import fields, marshal_with
|
|||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from models.provider import ProviderName
|
||||
from models.model import App
|
||||
|
||||
|
||||
|
@ -35,13 +33,12 @@ class AppParameterApi(AppApiResource):
|
|||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
app_model_config = app_model.app_model_config
|
||||
provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
|
|||
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
|
||||
ProviderNotSupportSpeechToTextError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from models.model import App, AppModelConfig
|
||||
from services.audio_service import AudioService
|
||||
|
|
|
@ -14,7 +14,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
|
|||
ProviderModelCurrentlyNotSupportError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
|
|
|
@ -11,7 +11,7 @@ from controllers.service_api.app.error import ProviderNotInitializeError
|
|||
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
|
||||
DatasetNotInitedError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
from core.model_providers.error import ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
|
|
|
@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
|
|||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from models.provider import ProviderName
|
||||
from models.model import App
|
||||
|
||||
|
||||
|
@ -34,13 +32,12 @@ class AppParameterApi(WebApiResource):
|
|||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
app_model_config = app_model.app_model_config
|
||||
provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro
|
|||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
|
|
|
@ -14,7 +14,7 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro
|
|||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
|
|
|
@ -14,7 +14,7 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi
|
|||
AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from services.completion_service import CompletionService
|
||||
|
|
|
@ -1,36 +0,0 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
|
||||
import langchain
|
||||
from flask import Flask
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.prompt.prompt_template import OneLineFormatter
|
||||
|
||||
|
||||
class HostedOpenAICredential(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
class HostedAnthropicCredential(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
class HostedLLMCredentials(BaseModel):
|
||||
openai: Optional[HostedOpenAICredential] = None
|
||||
anthropic: Optional[HostedAnthropicCredential] = None
|
||||
|
||||
|
||||
hosted_llm_credentials = HostedLLMCredentials()
|
||||
|
||||
|
||||
def init_app(app: Flask):
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
|
||||
langchain.verbose = True
|
||||
|
||||
if app.config.get("OPENAI_API_KEY"):
|
||||
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
|
||||
|
||||
if app.config.get("ANTHROPIC_API_KEY"):
|
||||
hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
|
|
@ -1,20 +1,17 @@
|
|||
from typing import cast, List
|
||||
from typing import List
|
||||
|
||||
from langchain import OpenAI
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.schema import BaseMessage
|
||||
|
||||
from core.constant import llm_constant
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
|
||||
class CalcTokenMixin:
|
||||
|
||||
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
|
||||
llm = cast(ChatOpenAI, llm)
|
||||
return llm.get_num_tokens_from_messages(messages)
|
||||
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
return model_instance.get_num_tokens(to_prompt_messages(messages))
|
||||
|
||||
def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
|
||||
def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""
|
||||
Got the rest tokens available for the model after excluding messages tokens and completion max tokens
|
||||
|
||||
|
@ -22,10 +19,9 @@ class CalcTokenMixin:
|
|||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
llm = cast(ChatOpenAI, llm)
|
||||
llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
|
||||
completion_max_tokens = llm.max_tokens
|
||||
used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
|
||||
llm_max_tokens = model_instance.model_rules.max_tokens.max
|
||||
completion_max_tokens = model_instance.model_kwargs.max_tokens
|
||||
used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
|
||||
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
|
||||
|
||||
return rest_tokens
|
||||
|
|
|
@ -4,9 +4,11 @@ from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
|||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
|
@ -14,6 +16,12 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
|||
"""
|
||||
An Multi Dataset Retrieve Agent driven by Router.
|
||||
"""
|
||||
model_instance: BaseLLM
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
|
|
|
@ -6,7 +6,8 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
|
|||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
|
||||
|
@ -84,7 +85,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
|||
|
||||
# summarize messages if rest_tokens < 0
|
||||
try:
|
||||
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
|
||||
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
|
|
|
@ -3,20 +3,28 @@ from typing import cast, List
|
|||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.chat_models.openai import _convert_message_to_dict
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel
|
||||
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
|
||||
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_llm: BaseLanguageModel
|
||||
model_instance: BaseLLM
|
||||
|
||||
def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
|
||||
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
|
||||
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
|
||||
if rest_tokens >= 0:
|
||||
return messages
|
||||
|
|
|
@ -6,7 +6,8 @@ from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFuncti
|
|||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
|
||||
|
@ -84,7 +85,7 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
|
|||
|
||||
# summarize messages if rest_tokens < 0
|
||||
try:
|
||||
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
|
||||
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
|
|
162
api/core/agent/agent/structed_multi_dataset_router_agent.py
Normal file
162
api/core/agent/agent/structed_multi_dataset_router_agent.py
Normal file
|
@ -0,0 +1,162 @@
|
|||
import re
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain import BasePromptTemplate
|
||||
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
||||
Valid "action" values: "Final Answer" or {tool_names}
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{{{{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $INPUT
|
||||
}}}}
|
||||
```
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
Thought: consider previous and subsequent steps
|
||||
Action:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
Observation: action result
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{{{{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
}}}}
|
||||
```"""
|
||||
|
||||
|
||||
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
model_instance: BaseLLM
|
||||
dataset_tools: Sequence[BaseTool]
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
Using the ReACT mode to determine whether an agent is needed is costly,
|
||||
so it's better to just use an Agent for reasoning, which is cheaper.
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
return True
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
if len(self.dataset_tools) == 0:
|
||||
return AgentFinish(return_values={"output": ''}, log='')
|
||||
elif len(self.dataset_tools) == 1:
|
||||
tool = next(iter(self.dataset_tools))
|
||||
tool = cast(DatasetRetrieverTool, tool)
|
||||
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
|
||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
|
||||
try:
|
||||
return self.output_parser.parse(full_output)
|
||||
except OutputParserException:
|
||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
||||
"I don't know how to respond to that."}, "")
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
) -> BasePromptTemplate:
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
|
||||
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
|
||||
formatted_tools = "\n".join(tool_strings)
|
||||
unique_tool_names = set(tool.name for tool in tools)
|
||||
tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
_memory_prompts = memory_prompts or []
|
||||
messages = [
|
||||
SystemMessagePromptTemplate.from_template(template),
|
||||
*_memory_prompts,
|
||||
HumanMessagePromptTemplate.from_template(human_message_template),
|
||||
]
|
||||
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
output_parser=output_parser,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
dataset_tools=tools,
|
||||
**kwargs,
|
||||
)
|
|
@ -14,7 +14,7 @@ from langchain.tools import BaseTool
|
|||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
|
||||
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
||||
|
@ -53,6 +53,12 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
|||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_llm: BaseLanguageModel
|
||||
model_instance: BaseLLM
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
|
@ -89,7 +95,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
|||
if prompts:
|
||||
messages = prompts[0].to_messages()
|
||||
|
||||
rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
|
||||
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
|
||||
if rest_tokens < 0:
|
||||
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ import logging
|
|||
from typing import Union, Optional
|
||||
|
||||
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.tools import BaseTool
|
||||
|
@ -13,14 +12,17 @@ from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
|
|||
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
|
||||
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
|
||||
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
|
||||
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
|
||||
from langchain.agents import AgentExecutor as LCAgentExecutor
|
||||
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
class PlanningStrategy(str, enum.Enum):
|
||||
ROUTER = 'router'
|
||||
REACT_ROUTER = 'react_router'
|
||||
REACT = 'react'
|
||||
FUNCTION_CALL = 'function_call'
|
||||
MULTI_FUNCTION_CALL = 'multi_function_call'
|
||||
|
@ -28,10 +30,9 @@ class PlanningStrategy(str, enum.Enum):
|
|||
|
||||
class AgentConfiguration(BaseModel):
|
||||
strategy: PlanningStrategy
|
||||
llm: BaseLanguageModel
|
||||
model_instance: BaseLLM
|
||||
tools: list[BaseTool]
|
||||
summary_llm: BaseLanguageModel
|
||||
dataset_llm: BaseLanguageModel
|
||||
summary_model_instance: BaseLLM
|
||||
memory: Optional[BaseChatMemory] = None
|
||||
callbacks: Callbacks = None
|
||||
max_iterations: int = 6
|
||||
|
@ -60,36 +61,49 @@ class AgentExecutor:
|
|||
def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
|
||||
if self.configuration.strategy == PlanningStrategy.REACT:
|
||||
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
|
||||
llm=self.configuration.llm,
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
summary_llm=self.configuration.summary_llm,
|
||||
summary_llm=self.configuration.summary_model_instance.client,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
|
||||
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
|
||||
llm=self.configuration.llm,
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_llm,
|
||||
summary_llm=self.configuration.summary_model_instance.client,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
|
||||
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
|
||||
llm=self.configuration.llm,
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_llm,
|
||||
summary_llm=self.configuration.summary_model_instance.client,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.ROUTER:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
|
||||
agent = MultiDatasetRouterAgent.from_llm_and_tools(
|
||||
llm=self.configuration.dataset_llm,
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
|
||||
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
verbose=True
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
|
||||
|
||||
|
|
|
@ -10,15 +10,16 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
|
|||
|
||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
|
||||
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
|
||||
def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
|
||||
"""Initialize callback handler."""
|
||||
self.model_name = model_name
|
||||
self.model_instant = model_instant
|
||||
self.conversation_message_task = conversation_message_task
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
|
@ -152,7 +153,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
|||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||
|
||||
self.conversation_message_task.on_agent_end(
|
||||
self._message_agent_thought, self.model_name, self._current_loop
|
||||
self._message_agent_thought, self.model_instant, self._current_loop
|
||||
)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
|
@ -183,7 +184,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
|||
)
|
||||
|
||||
self.conversation_message_task.on_agent_end(
|
||||
self._message_agent_thought, self.model_name, self._current_loop
|
||||
self._message_agent_thought, self.model_instant, self._current_loop
|
||||
)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
|
|
|
@ -3,18 +3,20 @@ import time
|
|||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel
|
||||
from langchain.schema import LLMResult, BaseMessage
|
||||
|
||||
from core.callback_handler.entity.llm_message import LLMMessage
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
|
||||
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
|
||||
class LLMCallbackHandler(BaseCallbackHandler):
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, llm: BaseLanguageModel,
|
||||
def __init__(self, model_instance: BaseLLM,
|
||||
conversation_message_task: ConversationMessageTask):
|
||||
self.llm = llm
|
||||
self.model_instance = model_instance
|
||||
self.llm_message = LLMMessage()
|
||||
self.start_at = None
|
||||
self.conversation_message_task = conversation_message_task
|
||||
|
@ -46,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
|||
})
|
||||
|
||||
self.llm_message.prompt = real_prompts
|
||||
self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
|
||||
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
|
@ -58,7 +60,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
|||
"text": prompts[0]
|
||||
}]
|
||||
|
||||
self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
|
||||
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
end_at = time.perf_counter()
|
||||
|
@ -68,7 +70,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
|||
self.conversation_message_task.append_message_text(response.generations[0][0].text)
|
||||
self.llm_message.completion = response.generations[0][0].text
|
||||
|
||||
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
|
||||
|
||||
self.conversation_message_task.save_message(self.llm_message)
|
||||
|
||||
|
@ -89,7 +91,9 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
|||
if self.conversation_message_task.streaming:
|
||||
end_at = time.perf_counter()
|
||||
self.llm_message.latency = end_at - self.start_at
|
||||
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)]
|
||||
)
|
||||
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
|
||||
else:
|
||||
logging.error(error)
|
||||
|
|
|
@ -5,9 +5,7 @@ from typing import Any, Dict, Union
|
|||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.callback_handler.entity.chain_result import ChainResult
|
||||
from core.constant import llm_constant
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
|
||||
|
||||
|
|
|
@ -2,27 +2,19 @@ import logging
|
|||
import re
|
||||
from typing import Optional, List, Union, Tuple
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms import BaseLLM
|
||||
from langchain.schema import BaseMessage, HumanMessage
|
||||
from langchain.schema import BaseMessage
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
|
||||
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
|
||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||
from core.constant import llm_constant
|
||||
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
|
||||
DifyStdOutCallbackHandler
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
|
||||
from core.llm.error import LLMBadRequestError
|
||||
from core.llm.fake import FakeLLM
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.orchestrator_rule_parser import OrchestratorRuleParser
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
|
@ -51,12 +43,10 @@ class Completion:
|
|||
|
||||
inputs = conversation.inputs
|
||||
|
||||
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
|
||||
mode=app.mode,
|
||||
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
|
||||
tenant_id=app.tenant_id,
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs
|
||||
model_config=app_model_config.model_dict,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
conversation_message_task = ConversationMessageTask(
|
||||
|
@ -68,10 +58,17 @@ class Completion:
|
|||
is_override=is_override,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
streaming=streaming
|
||||
streaming=streaming,
|
||||
model_instance=final_model_instance
|
||||
)
|
||||
|
||||
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
|
||||
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
|
||||
mode=app.mode,
|
||||
model_instance=final_model_instance,
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs
|
||||
)
|
||||
|
||||
# init orchestrator rule parser
|
||||
orchestrator_rule_parser = OrchestratorRuleParser(
|
||||
|
@ -80,6 +77,7 @@ class Completion:
|
|||
)
|
||||
|
||||
# parse sensitive_word_avoidance_chain
|
||||
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
|
||||
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
|
||||
if sensitive_word_avoidance_chain:
|
||||
query = sensitive_word_avoidance_chain.run(query)
|
||||
|
@ -102,15 +100,14 @@ class Completion:
|
|||
# run the final llm
|
||||
try:
|
||||
cls.run_final_llm(
|
||||
tenant_id=app.tenant_id,
|
||||
model_instance=final_model_instance,
|
||||
mode=app.mode,
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
agent_execute_result=agent_execute_result,
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
streaming=streaming
|
||||
memory=memory
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
return
|
||||
|
@ -121,31 +118,20 @@ class Completion:
|
|||
return
|
||||
|
||||
@classmethod
|
||||
def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
|
||||
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
|
||||
agent_execute_result: Optional[AgentExecuteResult],
|
||||
conversation_message_task: ConversationMessageTask,
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
|
||||
# When no extra pre prompt is specified,
|
||||
# the output of the agent can be used directly as the main output content without calling LLM again
|
||||
fake_response = None
|
||||
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
|
||||
and agent_execute_result.strategy != PlanningStrategy.ROUTER:
|
||||
final_llm = FakeLLM(response=agent_execute_result.output,
|
||||
origin_llm=agent_execute_result.configuration.llm,
|
||||
streaming=streaming)
|
||||
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
|
||||
response = final_llm.generate([[HumanMessage(content=query)]])
|
||||
return response
|
||||
|
||||
final_llm = LLMBuilder.to_llm_from_model(
|
||||
tenant_id=tenant_id,
|
||||
model=app_model_config.model_dict,
|
||||
streaming=streaming
|
||||
)
|
||||
fake_response = agent_execute_result.output
|
||||
|
||||
# get llm prompt
|
||||
prompt, stop_words = cls.get_main_llm_prompt(
|
||||
prompt_messages, stop_words = cls.get_main_llm_prompt(
|
||||
mode=mode,
|
||||
llm=final_llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
query=query,
|
||||
|
@ -154,25 +140,26 @@ class Completion:
|
|||
memory=memory
|
||||
)
|
||||
|
||||
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
final_llm=final_llm,
|
||||
model=app_model_config.model_dict,
|
||||
prompt=prompt,
|
||||
mode=mode
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
response = final_llm.generate([prompt], stop_words)
|
||||
response = model_instance.run(
|
||||
messages=prompt_messages,
|
||||
stop=stop_words,
|
||||
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
|
||||
fake_response=fake_response
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
|
||||
def get_main_llm_prompt(cls, mode: str, model: dict,
|
||||
pre_prompt: str, query: str, inputs: dict,
|
||||
agent_execute_result: Optional[AgentExecuteResult],
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
|
||||
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
|
||||
Tuple[List[PromptMessage], Optional[List[str]]]:
|
||||
if mode == 'completion':
|
||||
prompt_template = JinjaPromptTemplate.from_template(
|
||||
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
|
@ -200,11 +187,7 @@ And answer according to the language of the user's question.
|
|||
**prompt_inputs
|
||||
)
|
||||
|
||||
if isinstance(llm, BaseChatModel):
|
||||
# use chat llm as completion model
|
||||
return [HumanMessage(content=prompt_content)], None
|
||||
else:
|
||||
return prompt_content, None
|
||||
return [PromptMessage(content=prompt_content)], None
|
||||
else:
|
||||
messages: List[BaseMessage] = []
|
||||
|
||||
|
@ -249,12 +232,14 @@ And answer according to the language of the user's question.
|
|||
inputs=human_inputs
|
||||
)
|
||||
|
||||
curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
|
||||
model_name = model['name']
|
||||
max_tokens = model.get("completion_params").get('max_tokens')
|
||||
rest_tokens = llm_constant.max_context_token_length[model_name] \
|
||||
- max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
if memory.model_instance.model_rules.max_tokens.max:
|
||||
curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
|
||||
max_tokens = model.get("completion_params").get('max_tokens')
|
||||
rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
else:
|
||||
rest_tokens = 2000
|
||||
|
||||
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
|
||||
human_message_prompt += "\n\n" if human_message_prompt else ""
|
||||
human_message_prompt += "Here is the chat histories between human and assistant, " \
|
||||
|
@ -274,17 +259,7 @@ And answer according to the language of the user's question.
|
|||
for message in messages:
|
||||
message.content = re.sub(r'<\|.*?\|>', '', message.content)
|
||||
|
||||
return messages, ['\nHuman:', '</histories>']
|
||||
|
||||
@classmethod
|
||||
def get_llm_callbacks(cls, llm: BaseLanguageModel,
|
||||
streaming: bool,
|
||||
conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
|
||||
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
|
||||
if streaming:
|
||||
return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
|
||||
else:
|
||||
return [llm_callback_handler, DifyStdOutCallbackHandler()]
|
||||
return to_prompt_messages(messages), ['\nHuman:', '</histories>']
|
||||
|
||||
@classmethod
|
||||
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
|
||||
|
@ -300,15 +275,15 @@ And answer according to the language of the user's question.
|
|||
conversation: Conversation,
|
||||
**kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
|
||||
# only for calc token in memory
|
||||
memory_llm = LLMBuilder.to_llm_from_model(
|
||||
memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
|
||||
tenant_id=tenant_id,
|
||||
model=app_model_config.model_dict
|
||||
model_config=app_model_config.model_dict
|
||||
)
|
||||
|
||||
# use llm config from conversation
|
||||
memory = ReadOnlyConversationTokenDBBufferSharedMemory(
|
||||
conversation=conversation,
|
||||
llm=memory_llm,
|
||||
model_instance=memory_model_instance,
|
||||
max_token_limit=kwargs.get("max_token_limit", 2048),
|
||||
memory_key=kwargs.get("memory_key", "chat_history"),
|
||||
return_messages=kwargs.get("return_messages", True),
|
||||
|
@ -320,21 +295,20 @@ And answer according to the language of the user's question.
|
|||
return memory
|
||||
|
||||
@classmethod
|
||||
def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
|
||||
def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
|
||||
query: str, inputs: dict) -> int:
|
||||
llm = LLMBuilder.to_llm_from_model(
|
||||
tenant_id=tenant_id,
|
||||
model=app_model_config.model_dict
|
||||
)
|
||||
model_limited_tokens = model_instance.model_rules.max_tokens.max
|
||||
max_tokens = model_instance.get_model_kwargs().max_tokens
|
||||
|
||||
model_name = app_model_config.model_dict.get("name")
|
||||
model_limited_tokens = llm_constant.max_context_token_length[model_name]
|
||||
max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
|
||||
if model_limited_tokens is None:
|
||||
return -1
|
||||
|
||||
if max_tokens is None:
|
||||
max_tokens = 0
|
||||
|
||||
# get prompt without memory and context
|
||||
prompt, _ = cls.get_main_llm_prompt(
|
||||
prompt_messages, _ = cls.get_main_llm_prompt(
|
||||
mode=mode,
|
||||
llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
query=query,
|
||||
|
@ -343,9 +317,7 @@ And answer according to the language of the user's question.
|
|||
memory=None
|
||||
)
|
||||
|
||||
prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
|
||||
else llm.get_num_tokens_from_messages(prompt)
|
||||
|
||||
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
|
||||
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
|
||||
if rest_tokens < 0:
|
||||
raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
|
||||
|
@ -354,36 +326,40 @@ And answer according to the language of the user's question.
|
|||
return rest_tokens
|
||||
|
||||
@classmethod
|
||||
def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
|
||||
prompt: Union[str, List[BaseMessage]], mode: str):
|
||||
def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
|
||||
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
model_name = model.get("name")
|
||||
model_limited_tokens = llm_constant.max_context_token_length[model_name]
|
||||
max_tokens = model.get("completion_params").get('max_tokens')
|
||||
model_limited_tokens = model_instance.model_rules.max_tokens.max
|
||||
max_tokens = model_instance.get_model_kwargs().max_tokens
|
||||
|
||||
if mode == 'completion' and isinstance(final_llm, BaseLLM):
|
||||
prompt_tokens = final_llm.get_num_tokens(prompt)
|
||||
else:
|
||||
prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
|
||||
if model_limited_tokens is None:
|
||||
return
|
||||
|
||||
if max_tokens is None:
|
||||
max_tokens = 0
|
||||
|
||||
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
|
||||
|
||||
if prompt_tokens + max_tokens > model_limited_tokens:
|
||||
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
|
||||
final_llm.max_tokens = max_tokens
|
||||
|
||||
# update model instance max tokens
|
||||
model_kwargs = model_instance.get_model_kwargs()
|
||||
model_kwargs.max_tokens = max_tokens
|
||||
model_instance.set_model_kwargs(model_kwargs)
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
|
||||
app_model_config: AppModelConfig, user: Account, streaming: bool):
|
||||
|
||||
llm = LLMBuilder.to_llm_from_model(
|
||||
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
|
||||
tenant_id=app.tenant_id,
|
||||
model=app_model_config.model_dict,
|
||||
model_config=app_model_config.model_dict,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
# get llm prompt
|
||||
original_prompt, _ = cls.get_main_llm_prompt(
|
||||
old_prompt_messages, _ = cls.get_main_llm_prompt(
|
||||
mode="completion",
|
||||
llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=pre_prompt,
|
||||
query=message.query,
|
||||
|
@ -395,10 +371,9 @@ And answer according to the language of the user's question.
|
|||
original_completion = message.answer.strip()
|
||||
|
||||
prompt = MORE_LIKE_THIS_GENERATE_PROMPT
|
||||
prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)
|
||||
prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
|
||||
|
||||
if isinstance(llm, BaseChatModel):
|
||||
prompt = [HumanMessage(content=prompt)]
|
||||
prompt_messages = [PromptMessage(content=prompt)]
|
||||
|
||||
conversation_message_task = ConversationMessageTask(
|
||||
task_id=task_id,
|
||||
|
@ -408,16 +383,16 @@ And answer according to the language of the user's question.
|
|||
inputs=message.inputs,
|
||||
query=message.query,
|
||||
is_override=True if message.override_model_configs else False,
|
||||
streaming=streaming
|
||||
streaming=streaming,
|
||||
model_instance=final_model_instance
|
||||
)
|
||||
|
||||
llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
final_llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
prompt=prompt,
|
||||
mode='completion'
|
||||
model_instance=final_model_instance,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
llm.generate([prompt])
|
||||
final_model_instance.run(
|
||||
messages=prompt_messages,
|
||||
callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
|
||||
)
|
||||
|
|
|
@ -1,109 +0,0 @@
|
|||
from _decimal import Decimal
|
||||
|
||||
models = {
|
||||
'claude-instant-1': 'anthropic', # 100,000 tokens
|
||||
'claude-2': 'anthropic', # 100,000 tokens
|
||||
'gpt-4': 'openai', # 8,192 tokens
|
||||
'gpt-4-32k': 'openai', # 32,768 tokens
|
||||
'gpt-3.5-turbo': 'openai', # 4,096 tokens
|
||||
'gpt-3.5-turbo-16k': 'openai', # 16384 tokens
|
||||
'text-davinci-003': 'openai', # 4,097 tokens
|
||||
'text-davinci-002': 'openai', # 4,097 tokens
|
||||
'text-curie-001': 'openai', # 2,049 tokens
|
||||
'text-babbage-001': 'openai', # 2,049 tokens
|
||||
'text-ada-001': 'openai', # 2,049 tokens
|
||||
'text-embedding-ada-002': 'openai', # 8191 tokens, 1536 dimensions
|
||||
'whisper-1': 'openai'
|
||||
}
|
||||
|
||||
max_context_token_length = {
|
||||
'claude-instant-1': 100000,
|
||||
'claude-2': 100000,
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-3.5-turbo': 4096,
|
||||
'gpt-3.5-turbo-16k': 16384,
|
||||
'text-davinci-003': 4097,
|
||||
'text-davinci-002': 4097,
|
||||
'text-curie-001': 2049,
|
||||
'text-babbage-001': 2049,
|
||||
'text-ada-001': 2049,
|
||||
'text-embedding-ada-002': 8191,
|
||||
}
|
||||
|
||||
models_by_mode = {
|
||||
'chat': [
|
||||
'claude-instant-1', # 100,000 tokens
|
||||
'claude-2', # 100,000 tokens
|
||||
'gpt-4', # 8,192 tokens
|
||||
'gpt-4-32k', # 32,768 tokens
|
||||
'gpt-3.5-turbo', # 4,096 tokens
|
||||
'gpt-3.5-turbo-16k', # 16,384 tokens
|
||||
],
|
||||
'completion': [
|
||||
'claude-instant-1', # 100,000 tokens
|
||||
'claude-2', # 100,000 tokens
|
||||
'gpt-4', # 8,192 tokens
|
||||
'gpt-4-32k', # 32,768 tokens
|
||||
'gpt-3.5-turbo', # 4,096 tokens
|
||||
'gpt-3.5-turbo-16k', # 16,384 tokens
|
||||
'text-davinci-003', # 4,097 tokens
|
||||
'text-davinci-002' # 4,097 tokens
|
||||
'text-curie-001', # 2,049 tokens
|
||||
'text-babbage-001', # 2,049 tokens
|
||||
'text-ada-001' # 2,049 tokens
|
||||
],
|
||||
'embedding': [
|
||||
'text-embedding-ada-002' # 8191 tokens, 1536 dimensions
|
||||
]
|
||||
}
|
||||
|
||||
model_currency = 'USD'
|
||||
|
||||
model_prices = {
|
||||
'claude-instant-1': {
|
||||
'prompt': Decimal('0.00163'),
|
||||
'completion': Decimal('0.00551'),
|
||||
},
|
||||
'claude-2': {
|
||||
'prompt': Decimal('0.01102'),
|
||||
'completion': Decimal('0.03268'),
|
||||
},
|
||||
'gpt-4': {
|
||||
'prompt': Decimal('0.03'),
|
||||
'completion': Decimal('0.06'),
|
||||
},
|
||||
'gpt-4-32k': {
|
||||
'prompt': Decimal('0.06'),
|
||||
'completion': Decimal('0.12')
|
||||
},
|
||||
'gpt-3.5-turbo': {
|
||||
'prompt': Decimal('0.0015'),
|
||||
'completion': Decimal('0.002')
|
||||
},
|
||||
'gpt-3.5-turbo-16k': {
|
||||
'prompt': Decimal('0.003'),
|
||||
'completion': Decimal('0.004')
|
||||
},
|
||||
'text-davinci-003': {
|
||||
'prompt': Decimal('0.02'),
|
||||
'completion': Decimal('0.02')
|
||||
},
|
||||
'text-curie-001': {
|
||||
'prompt': Decimal('0.002'),
|
||||
'completion': Decimal('0.002')
|
||||
},
|
||||
'text-babbage-001': {
|
||||
'prompt': Decimal('0.0005'),
|
||||
'completion': Decimal('0.0005')
|
||||
},
|
||||
'text-ada-001': {
|
||||
'prompt': Decimal('0.0004'),
|
||||
'completion': Decimal('0.0004')
|
||||
},
|
||||
'text-embedding-ada-002': {
|
||||
'usage': Decimal('0.0001'),
|
||||
}
|
||||
}
|
||||
|
||||
agent_model_name = 'text-davinci-003'
|
|
@ -6,9 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
|
|||
from core.callback_handler.entity.dataset_query import DatasetQueryObj
|
||||
from core.callback_handler.entity.llm_message import LLMMessage
|
||||
from core.callback_handler.entity.chain_result import ChainResult
|
||||
from core.constant import llm_constant
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.llm.provider.llm_provider_service import LLMProviderService
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from events.message_event import message_was_created
|
||||
|
@ -16,12 +16,11 @@ from extensions.ext_database import db
|
|||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import DatasetQuery
|
||||
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
|
||||
from models.provider import ProviderType, Provider
|
||||
|
||||
|
||||
class ConversationMessageTask:
|
||||
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
|
||||
inputs: dict, query: str, streaming: bool,
|
||||
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
|
||||
conversation: Optional[Conversation] = None, is_override: bool = False):
|
||||
self.task_id = task_id
|
||||
|
||||
|
@ -38,9 +37,12 @@ class ConversationMessageTask:
|
|||
self.conversation = conversation
|
||||
self.is_new_conversation = False
|
||||
|
||||
self.model_instance = model_instance
|
||||
|
||||
self.message = None
|
||||
|
||||
self.model_dict = self.app_model_config.model_dict
|
||||
self.provider_name = self.model_dict.get('provider')
|
||||
self.model_name = self.model_dict.get('name')
|
||||
self.mode = app.mode
|
||||
|
||||
|
@ -56,9 +58,6 @@ class ConversationMessageTask:
|
|||
)
|
||||
|
||||
def init(self):
|
||||
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
|
||||
self.model_dict['provider'] = provider_name
|
||||
|
||||
override_model_configs = None
|
||||
if self.is_override:
|
||||
override_model_configs = {
|
||||
|
@ -89,15 +88,19 @@ class ConversationMessageTask:
|
|||
if self.app_model_config.pre_prompt:
|
||||
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
|
||||
system_instruction = system_message.content
|
||||
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
|
||||
system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=self.tenant_id,
|
||||
model_provider_name=self.provider_name,
|
||||
model_name=self.model_name
|
||||
)
|
||||
system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))
|
||||
|
||||
if not self.conversation:
|
||||
self.is_new_conversation = True
|
||||
self.conversation = Conversation(
|
||||
app_id=self.app_model_config.app_id,
|
||||
app_model_config_id=self.app_model_config.id,
|
||||
model_provider=self.model_dict.get('provider'),
|
||||
model_provider=self.provider_name,
|
||||
model_id=self.model_name,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
mode=self.mode,
|
||||
|
@ -117,7 +120,7 @@ class ConversationMessageTask:
|
|||
|
||||
self.message = Message(
|
||||
app_id=self.app_model_config.app_id,
|
||||
model_provider=self.model_dict.get('provider'),
|
||||
model_provider=self.provider_name,
|
||||
model_id=self.model_name,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
conversation_id=self.conversation.id,
|
||||
|
@ -131,7 +134,7 @@ class ConversationMessageTask:
|
|||
answer_unit_price=0,
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency=llm_constant.model_currency,
|
||||
currency=self.model_instance.get_currency(),
|
||||
from_source=('console' if isinstance(self.user, Account) else 'api'),
|
||||
from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
|
||||
from_account_id=(self.user.id if isinstance(self.user, Account) else None),
|
||||
|
@ -145,12 +148,10 @@ class ConversationMessageTask:
|
|||
self._pub_handler.pub_text(text)
|
||||
|
||||
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
|
||||
model_name = self.app_model_config.model_dict.get('name')
|
||||
|
||||
message_tokens = llm_message.prompt_tokens
|
||||
answer_tokens = llm_message.completion_tokens
|
||||
message_unit_price = llm_constant.model_prices[model_name]['prompt']
|
||||
answer_unit_price = llm_constant.model_prices[model_name]['completion']
|
||||
message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
|
||||
answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
|
||||
|
||||
total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
|
||||
|
||||
|
@ -163,8 +164,6 @@ class ConversationMessageTask:
|
|||
self.message.provider_response_latency = llm_message.latency
|
||||
self.message.total_price = total_price
|
||||
|
||||
self.update_provider_quota()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
message_was_created.send(
|
||||
|
@ -176,20 +175,6 @@ class ConversationMessageTask:
|
|||
if not by_stopped:
|
||||
self.end()
|
||||
|
||||
def update_provider_quota(self):
|
||||
llm_provider_service = LLMProviderService(
|
||||
tenant_id=self.app.tenant_id,
|
||||
provider_name=self.message.model_provider,
|
||||
)
|
||||
|
||||
provider = llm_provider_service.get_provider_db_record()
|
||||
if provider and provider.provider_type == ProviderType.SYSTEM.value:
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == self.app.tenant_id,
|
||||
Provider.provider_name == provider.provider_name,
|
||||
Provider.quota_limit > Provider.quota_used
|
||||
).update({'quota_used': Provider.quota_used + 1})
|
||||
|
||||
def init_chain(self, chain_result: ChainResult):
|
||||
message_chain = MessageChain(
|
||||
message_id=self.message.id,
|
||||
|
@ -229,10 +214,10 @@ class ConversationMessageTask:
|
|||
|
||||
return message_agent_thought
|
||||
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
|
||||
agent_loop: AgentLoop):
|
||||
agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
|
||||
agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
|
||||
agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
|
||||
agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
|
||||
|
||||
loop_message_tokens = agent_loop.prompt_tokens
|
||||
loop_answer_tokens = agent_loop.completion_tokens
|
||||
|
@ -253,7 +238,7 @@ class ConversationMessageTask:
|
|||
message_agent_thought.latency = agent_loop.latency
|
||||
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
|
||||
message_agent_thought.total_price = loop_total_price
|
||||
message_agent_thought.currency = llm_constant.model_currency
|
||||
message_agent_thought.currency = agent_model_instant.get_currency()
|
||||
db.session.flush()
|
||||
|
||||
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Sequence
|
|||
from langchain.schema import Document
|
||||
from sqlalchemy import func
|
||||
|
||||
from core.llm.token_calculator import TokenCalculator
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
|
||||
|
@ -13,12 +13,10 @@ class DatesetDocumentStore:
|
|||
self,
|
||||
dataset: Dataset,
|
||||
user_id: str,
|
||||
embedding_model_name: str,
|
||||
document_id: Optional[str] = None,
|
||||
):
|
||||
self._dataset = dataset
|
||||
self._user_id = user_id
|
||||
self._embedding_model_name = embedding_model_name
|
||||
self._document_id = document_id
|
||||
|
||||
@classmethod
|
||||
|
@ -39,10 +37,6 @@ class DatesetDocumentStore:
|
|||
def user_id(self) -> Any:
|
||||
return self._user_id
|
||||
|
||||
@property
|
||||
def embedding_model_name(self) -> Any:
|
||||
return self._embedding_model_name
|
||||
|
||||
@property
|
||||
def docs(self) -> Dict[str, Document]:
|
||||
document_segments = db.session.query(DocumentSegment).filter(
|
||||
|
@ -74,6 +68,10 @@ class DatesetDocumentStore:
|
|||
if max_position is None:
|
||||
max_position = 0
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=self._dataset.tenant_id
|
||||
)
|
||||
|
||||
for doc in docs:
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError("doc must be a Document")
|
||||
|
@ -88,7 +86,7 @@ class DatesetDocumentStore:
|
|||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)
|
||||
tokens = embedding_model.get_num_tokens(doc.page_content)
|
||||
|
||||
if not segment_document:
|
||||
max_position += 1
|
||||
|
|
|
@ -4,14 +4,14 @@ from typing import List
|
|||
from langchain.embeddings.base import Embeddings
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.dataset import Embedding
|
||||
|
||||
|
||||
class CacheEmbedding(Embeddings):
|
||||
def __init__(self, embeddings: Embeddings):
|
||||
def __init__(self, embeddings: BaseEmbedding):
|
||||
self._embeddings = embeddings
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
|
@ -21,48 +21,54 @@ class CacheEmbedding(Embeddings):
|
|||
embedding_queue_texts = []
|
||||
for text in texts:
|
||||
hash = helper.generate_text_hash(text)
|
||||
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
|
||||
embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
|
||||
if embedding:
|
||||
text_embeddings.append(embedding.get_embedding())
|
||||
else:
|
||||
embedding_queue_texts.append(text)
|
||||
|
||||
embedding_results = self._embeddings.embed_documents(embedding_queue_texts)
|
||||
|
||||
i = 0
|
||||
for text in embedding_queue_texts:
|
||||
hash = helper.generate_text_hash(text)
|
||||
|
||||
if embedding_queue_texts:
|
||||
try:
|
||||
embedding = Embedding(hash=hash)
|
||||
embedding.set_embedding(embedding_results[i])
|
||||
db.session.add(embedding)
|
||||
db.session.commit()
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
continue
|
||||
except:
|
||||
logging.exception('Failed to add embedding to db')
|
||||
continue
|
||||
finally:
|
||||
i += 1
|
||||
embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
|
||||
except Exception as ex:
|
||||
raise self._embeddings.handle_exceptions(ex)
|
||||
|
||||
text_embeddings.extend(embedding_results)
|
||||
i = 0
|
||||
for text in embedding_queue_texts:
|
||||
hash = helper.generate_text_hash(text)
|
||||
|
||||
try:
|
||||
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
|
||||
embedding.set_embedding(embedding_results[i])
|
||||
db.session.add(embedding)
|
||||
db.session.commit()
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
continue
|
||||
except:
|
||||
logging.exception('Failed to add embedding to db')
|
||||
continue
|
||||
finally:
|
||||
i += 1
|
||||
|
||||
text_embeddings.extend(embedding_results)
|
||||
return text_embeddings
|
||||
|
||||
@handle_openai_exceptions
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed query text."""
|
||||
# use doc embedding cache or store if not exists
|
||||
hash = helper.generate_text_hash(text)
|
||||
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
|
||||
embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
|
||||
if embedding:
|
||||
return embedding.get_embedding()
|
||||
|
||||
embedding_results = self._embeddings.embed_query(text)
|
||||
try:
|
||||
embedding_results = self._embeddings.client.embed_query(text)
|
||||
except Exception as ex:
|
||||
raise self._embeddings.handle_exceptions(ex)
|
||||
|
||||
try:
|
||||
embedding = Embedding(hash=hash)
|
||||
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
|
||||
embedding.set_embedding(embedding_results)
|
||||
db.session.add(embedding)
|
||||
db.session.commit()
|
||||
|
@ -72,3 +78,5 @@ class CacheEmbedding(Embeddings):
|
|||
logging.exception('Failed to add embedding to db')
|
||||
|
||||
return embedding_results
|
||||
|
||||
|
||||
|
|
|
@ -1,13 +1,10 @@
|
|||
import logging
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage
|
||||
from langchain.schema import OutputParserException
|
||||
|
||||
from core.constant import llm_constant
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
from core.llm.token_calculator import TokenCalculator
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs
|
||||
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
|
||||
|
||||
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
||||
|
@ -15,9 +12,6 @@ from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTempla
|
|||
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
|
||||
GENERATOR_QA_PROMPT
|
||||
|
||||
# gpt-3.5-turbo works not well
|
||||
generate_base_model = 'text-davinci-003'
|
||||
|
||||
|
||||
class LLMGenerator:
|
||||
@classmethod
|
||||
|
@ -28,29 +22,35 @@ class LLMGenerator:
|
|||
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
|
||||
|
||||
prompt = prompt.format(query=query)
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
max_tokens=50,
|
||||
timeout=600
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=50
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(llm, BaseChatModel):
|
||||
prompt = [HumanMessage(content=prompt)]
|
||||
|
||||
response = llm.generate([prompt])
|
||||
answer = response.generations[0][0].text
|
||||
prompts = [PromptMessage(content=prompt)]
|
||||
response = model_instance.run(prompts)
|
||||
answer = response.content
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
|
||||
def generate_conversation_summary(cls, tenant_id: str, messages):
|
||||
max_tokens = 200
|
||||
model = 'gpt-3.5-turbo'
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
)
|
||||
|
||||
prompt = CONVERSATION_SUMMARY_PROMPT
|
||||
prompt_with_empty_context = prompt.format(context='')
|
||||
prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context)
|
||||
rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1
|
||||
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
|
||||
max_context_token_length = model_instance.model_rules.max_tokens.max
|
||||
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
|
||||
|
||||
context = ''
|
||||
for message in messages:
|
||||
|
@ -68,25 +68,16 @@ class LLMGenerator:
|
|||
answer = message.answer
|
||||
|
||||
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
|
||||
if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
|
||||
if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
|
||||
context += message_qa_text
|
||||
|
||||
if not context:
|
||||
return '[message too long, no summary]'
|
||||
|
||||
prompt = prompt.format(context=context)
|
||||
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
tenant_id=tenant_id,
|
||||
model_name=model,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
|
||||
if isinstance(llm, BaseChatModel):
|
||||
prompt = [HumanMessage(content=prompt)]
|
||||
|
||||
response = llm.generate([prompt])
|
||||
answer = response.generations[0][0].text
|
||||
prompts = [PromptMessage(content=prompt)]
|
||||
response = model_instance.run(prompts)
|
||||
answer = response.content
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
|
||||
|
@ -94,16 +85,13 @@ class LLMGenerator:
|
|||
prompt = INTRODUCTION_GENERATE_PROMPT
|
||||
prompt = prompt.format(prompt=pre_prompt)
|
||||
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
tenant_id=tenant_id,
|
||||
model_name=generate_base_model,
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
if isinstance(llm, BaseChatModel):
|
||||
prompt = [HumanMessage(content=prompt)]
|
||||
|
||||
response = llm.generate([prompt])
|
||||
answer = response.generations[0][0].text
|
||||
prompts = [PromptMessage(content=prompt)]
|
||||
response = model_instance.run(prompts)
|
||||
answer = response.content
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
|
||||
|
@ -119,23 +107,19 @@ class LLMGenerator:
|
|||
|
||||
_input = prompt.format_prompt(histories=histories)
|
||||
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
temperature=0,
|
||||
max_tokens=256
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=256,
|
||||
temperature=0
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(llm, BaseChatModel):
|
||||
query = [HumanMessage(content=_input.to_string())]
|
||||
else:
|
||||
query = _input.to_string()
|
||||
prompts = [PromptMessage(content=_input.to_string())]
|
||||
|
||||
try:
|
||||
output = llm(query)
|
||||
if isinstance(output, BaseMessage):
|
||||
output = output.content
|
||||
questions = output_parser.parse(output)
|
||||
output = model_instance.run(prompts)
|
||||
questions = output_parser.parse(output.content)
|
||||
except Exception:
|
||||
logging.exception("Error generating suggested questions after answer")
|
||||
questions = []
|
||||
|
@ -160,21 +144,19 @@ class LLMGenerator:
|
|||
|
||||
_input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
|
||||
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_name=generate_base_model,
|
||||
temperature=0,
|
||||
max_tokens=512
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=512,
|
||||
temperature=0
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(llm, BaseChatModel):
|
||||
query = [HumanMessage(content=_input.to_string())]
|
||||
else:
|
||||
query = _input.to_string()
|
||||
prompts = [PromptMessage(content=_input.to_string())]
|
||||
|
||||
try:
|
||||
output = llm(query)
|
||||
rule_config = output_parser.parse(output)
|
||||
output = model_instance.run(prompts)
|
||||
rule_config = output_parser.parse(output.content)
|
||||
except OutputParserException:
|
||||
raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
|
||||
except Exception:
|
||||
|
@ -188,25 +170,21 @@ class LLMGenerator:
|
|||
return rule_config
|
||||
|
||||
@classmethod
|
||||
async def generate_qa_document(cls, llm: StreamableOpenAI, query):
|
||||
def generate_qa_document(cls, tenant_id: str, query):
|
||||
prompt = GENERATOR_QA_PROMPT
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=2000
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(llm, BaseChatModel):
|
||||
prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
|
||||
prompts = [
|
||||
PromptMessage(content=prompt, type=MessageType.SYSTEM),
|
||||
PromptMessage(content=query)
|
||||
]
|
||||
|
||||
response = llm.generate([prompt])
|
||||
answer = response.generations[0][0].text
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
|
||||
def generate_qa_document_sync(cls, llm: StreamableOpenAI, query):
|
||||
prompt = GENERATOR_QA_PROMPT
|
||||
|
||||
|
||||
if isinstance(llm, BaseChatModel):
|
||||
prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
|
||||
|
||||
response = llm.generate([prompt])
|
||||
answer = response.generations[0][0].text
|
||||
response = model_instance.run(prompts)
|
||||
answer = response.content
|
||||
return answer.strip()
|
||||
|
|
20
api/core/helper/encrypter.py
Normal file
20
api/core/helper/encrypter.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
import base64
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs import rsa
|
||||
|
||||
from models.account import Tenant
|
||||
|
||||
|
||||
def obfuscated_token(token: str):
|
||||
return token[:6] + '*' * (len(token) - 8) + token[-2:]
|
||||
|
||||
|
||||
def encrypt_token(tenant_id: str, token: str):
|
||||
tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
|
||||
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
|
||||
return base64.b64encode(encrypted_token).decode()
|
||||
|
||||
|
||||
def decrypt_token(tenant_id: str, token: str):
|
||||
return rsa.decrypt(base64.b64decode(token), tenant_id)
|
|
@ -1,10 +1,9 @@
|
|||
from flask import current_app
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
|
@ -15,16 +14,11 @@ class IndexBuilder:
|
|||
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
|
||||
return None
|
||||
|
||||
model_credentials = LLMBuilder.get_model_credentials(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
|
||||
model_name='text-embedding-ada-002'
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
)
|
||||
|
||||
embeddings = CacheEmbedding(OpenAIEmbeddings(
|
||||
max_retries=1,
|
||||
**model_credentials
|
||||
))
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
return VectorIndex(
|
||||
dataset=dataset,
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import concurrent
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
|
@ -6,7 +5,6 @@ import re
|
|||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional, List, cast
|
||||
|
||||
from flask_login import current_user
|
||||
|
@ -18,11 +16,10 @@ from core.data_loader.loader.notion import NotionLoader
|
|||
from core.docstore.dataset_docstore import DatesetDocumentStore
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.index.index import IndexBuilder
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
from core.model_providers.error import ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import MessageType
|
||||
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
|
||||
from core.llm.token_calculator import TokenCalculator
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
|
@ -35,9 +32,8 @@ from models.source import DataSourceBinding
|
|||
|
||||
class IndexingRunner:
|
||||
|
||||
def __init__(self, embedding_model_name: str = "text-embedding-ada-002"):
|
||||
def __init__(self):
|
||||
self.storage = storage
|
||||
self.embedding_model_name = embedding_model_name
|
||||
|
||||
def run(self, dataset_documents: List[DatasetDocument]):
|
||||
"""Run the indexing process."""
|
||||
|
@ -227,11 +223,15 @@ class IndexingRunner:
|
|||
dataset_document.stopped_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict,
|
||||
def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
|
||||
doc_form: str = None) -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
tokens = 0
|
||||
preview_texts = []
|
||||
total_segments = 0
|
||||
|
@ -253,44 +253,49 @@ class IndexingRunner:
|
|||
splitter=splitter,
|
||||
processing_rule=processing_rule
|
||||
)
|
||||
|
||||
total_segments += len(documents)
|
||||
|
||||
for document in documents:
|
||||
if len(preview_texts) < 5:
|
||||
preview_texts.append(document.page_content)
|
||||
|
||||
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
|
||||
self.filter_string(document.page_content))
|
||||
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
|
||||
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
max_tokens=2000
|
||||
)
|
||||
response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
|
||||
document_qa_list = self.format_split_text(response)
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
|
||||
"currency": TokenCalculator.get_currency(self.embedding_model_name),
|
||||
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
}
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
|
||||
"currency": TokenCalculator.get_currency(self.embedding_model_name),
|
||||
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
|
||||
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# load data from notion
|
||||
tokens = 0
|
||||
preview_texts = []
|
||||
|
@ -336,31 +341,31 @@ class IndexingRunner:
|
|||
if len(preview_texts) < 5:
|
||||
preview_texts.append(document.page_content)
|
||||
|
||||
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
|
||||
tokens += embedding_model.get_num_tokens(document.page_content)
|
||||
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
max_tokens=2000
|
||||
)
|
||||
response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
|
||||
document_qa_list = self.format_split_text(response)
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
|
||||
"currency": TokenCalculator.get_currency(self.embedding_model_name),
|
||||
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
}
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
|
||||
"currency": TokenCalculator.get_currency(self.embedding_model_name),
|
||||
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
|
@ -459,7 +464,6 @@ class IndexingRunner:
|
|||
doc_store = DatesetDocumentStore(
|
||||
dataset=dataset,
|
||||
user_id=dataset_document.created_by,
|
||||
embedding_model_name=self.embedding_model_name,
|
||||
document_id=dataset_document.id
|
||||
)
|
||||
|
||||
|
@ -513,17 +517,12 @@ class IndexingRunner:
|
|||
all_documents.extend(split_documents)
|
||||
# processing qa document
|
||||
if document_form == 'qa_model':
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
tenant_id=tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
max_tokens=2000
|
||||
)
|
||||
for i in range(0, len(all_documents), 10):
|
||||
threads = []
|
||||
sub_documents = all_documents[i:i + 10]
|
||||
for doc in sub_documents:
|
||||
document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
|
||||
'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents})
|
||||
'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents})
|
||||
threads.append(document_format_thread)
|
||||
document_format_thread.start()
|
||||
for thread in threads:
|
||||
|
@ -531,13 +530,13 @@ class IndexingRunner:
|
|||
return all_qa_documents
|
||||
return all_documents
|
||||
|
||||
def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents):
|
||||
def format_qa_document(self, tenant_id: str, document_node, all_qa_documents):
|
||||
format_documents = []
|
||||
if document_node.page_content is None or not document_node.page_content.strip():
|
||||
return
|
||||
try:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
|
||||
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
qa_documents = []
|
||||
for result in document_qa_list:
|
||||
|
@ -638,6 +637,10 @@ class IndexingRunner:
|
|||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
)
|
||||
|
||||
# chunk nodes by chunk size
|
||||
indexing_start_at = time.perf_counter()
|
||||
tokens = 0
|
||||
|
@ -648,7 +651,7 @@ class IndexingRunner:
|
|||
chunk_documents = documents[i:i + chunk_size]
|
||||
|
||||
tokens += sum(
|
||||
TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
|
||||
embedding_model.get_num_tokens(document.page_content)
|
||||
for document in chunk_documents
|
||||
)
|
||||
|
||||
|
|
|
@ -1,148 +0,0 @@
|
|||
from typing import Union, Optional, List
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
||||
from core.constant import llm_constant
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.llm_provider_service import LLMProviderService
|
||||
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
|
||||
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
|
||||
from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
from models.provider import ProviderType, ProviderName
|
||||
|
||||
|
||||
class LLMBuilder:
|
||||
"""
|
||||
This class handles the following logic:
|
||||
1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config.
|
||||
2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below:
|
||||
OPENAI_API_TYPE=azure
|
||||
OPENAI_API_VERSION=2022-12-01
|
||||
OPENAI_API_BASE=https://your-resource-name.openai.azure.com
|
||||
OPENAI_API_KEY=<your Azure OpenAI API key>
|
||||
3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config.
|
||||
4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config.
|
||||
5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config.
|
||||
6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot be created through the admin interface.
|
||||
7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter.
|
||||
8. For providers with the provider_type 'System', the quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
|
||||
provider = cls.get_default_provider(tenant_id, model_name)
|
||||
|
||||
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
|
||||
|
||||
llm_cls = None
|
||||
mode = cls.get_mode_by_model(model_name)
|
||||
if mode == 'chat':
|
||||
if provider == ProviderName.OPENAI.value:
|
||||
llm_cls = StreamableChatOpenAI
|
||||
elif provider == ProviderName.AZURE_OPENAI.value:
|
||||
llm_cls = StreamableAzureChatOpenAI
|
||||
elif provider == ProviderName.ANTHROPIC.value:
|
||||
llm_cls = StreamableChatAnthropic
|
||||
elif mode == 'completion':
|
||||
if provider == ProviderName.OPENAI.value:
|
||||
llm_cls = StreamableOpenAI
|
||||
elif provider == ProviderName.AZURE_OPENAI.value:
|
||||
llm_cls = StreamableAzureOpenAI
|
||||
|
||||
if not llm_cls:
|
||||
raise ValueError(f"model name {model_name} is not supported.")
|
||||
|
||||
model_kwargs = {
|
||||
'model_name': model_name,
|
||||
'temperature': kwargs.get('temperature', 0),
|
||||
'max_tokens': kwargs.get('max_tokens', 256),
|
||||
'top_p': kwargs.get('top_p', 1),
|
||||
'frequency_penalty': kwargs.get('frequency_penalty', 0),
|
||||
'presence_penalty': kwargs.get('presence_penalty', 0),
|
||||
'callbacks': kwargs.get('callbacks', None),
|
||||
'streaming': kwargs.get('streaming', False),
|
||||
}
|
||||
|
||||
model_kwargs.update(model_credentials)
|
||||
model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
|
||||
|
||||
return llm_cls(**model_kwargs)
|
||||
|
||||
@classmethod
|
||||
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
|
||||
callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
|
||||
model_name = model.get("name")
|
||||
completion_params = model.get("completion_params", {})
|
||||
|
||||
return cls.to_llm(
|
||||
tenant_id=tenant_id,
|
||||
model_name=model_name,
|
||||
temperature=completion_params.get('temperature', 0),
|
||||
max_tokens=completion_params.get('max_tokens', 256),
|
||||
top_p=completion_params.get('top_p', 0),
|
||||
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
|
||||
presence_penalty=completion_params.get('presence_penalty', 0.1),
|
||||
streaming=streaming,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_mode_by_model(cls, model_name: str) -> str:
|
||||
if not model_name:
|
||||
raise ValueError(f"empty model name is not supported.")
|
||||
|
||||
if model_name in llm_constant.models_by_mode['chat']:
|
||||
return "chat"
|
||||
elif model_name in llm_constant.models_by_mode['completion']:
|
||||
return "completion"
|
||||
else:
|
||||
raise ValueError(f"model name {model_name} is not supported.")
|
||||
|
||||
@classmethod
|
||||
def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
|
||||
"""
|
||||
Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
|
||||
Raises an exception if the model_name is not found or if the provider is not found.
|
||||
"""
|
||||
if not model_name:
|
||||
raise Exception('model name not found')
|
||||
#
|
||||
# if model_name not in llm_constant.models:
|
||||
# raise Exception('model {} not found'.format(model_name))
|
||||
|
||||
# model_provider = llm_constant.models[model_name]
|
||||
|
||||
provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
|
||||
return provider_service.get_credentials(model_name)
|
||||
|
||||
@classmethod
|
||||
def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
|
||||
provider_name = llm_constant.models[model_name]
|
||||
|
||||
if provider_name == 'openai':
|
||||
# get the default provider (openai / azure_openai) for the tenant
|
||||
openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
|
||||
azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
|
||||
|
||||
provider = None
|
||||
if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value:
|
||||
provider = openai_provider
|
||||
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value:
|
||||
provider = azure_openai_provider
|
||||
elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value:
|
||||
provider = openai_provider
|
||||
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value:
|
||||
provider = azure_openai_provider
|
||||
|
||||
if not provider:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {provider_name} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
provider_name = provider.provider_name
|
||||
|
||||
return provider_name
|
|
@ -1,15 +0,0 @@
|
|||
import openai
|
||||
from models.provider import ProviderName
|
||||
|
||||
|
||||
class Moderation:
|
||||
|
||||
def __init__(self, provider: str, api_key: str):
|
||||
self.provider = provider
|
||||
self.api_key = api_key
|
||||
|
||||
if self.provider == ProviderName.OPENAI.value:
|
||||
self.client = openai.Moderation
|
||||
|
||||
def moderate(self, text):
|
||||
return self.client.create(input=text, api_key=self.api_key)
|
|
@ -1,138 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import anthropic
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core import hosted_llm_credentials
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.errors import ValidateFailedError
|
||||
from models.provider import ProviderName, ProviderType
|
||||
|
||||
|
||||
class AnthropicProvider(BaseProvider):
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
'id': 'claude-instant-1',
|
||||
'name': 'claude-instant-1',
|
||||
},
|
||||
{
|
||||
'id': 'claude-2',
|
||||
'name': 'claude-2',
|
||||
},
|
||||
]
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
return self.get_provider_api_key(model_id=model_id)
|
||||
|
||||
def get_provider_name(self):
|
||||
return ProviderName.ANTHROPIC
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the provider configs.
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key(only_custom=only_custom)
|
||||
except:
|
||||
config = {
|
||||
'anthropic_api_key': ''
|
||||
}
|
||||
|
||||
if obfuscated:
|
||||
if not config.get('anthropic_api_key'):
|
||||
config = {
|
||||
'anthropic_api_key': ''
|
||||
}
|
||||
|
||||
config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
|
||||
return config
|
||||
|
||||
return config
|
||||
|
||||
def get_encrypted_token(self, config: Union[dict | str]):
|
||||
"""
|
||||
Returns the encrypted token.
|
||||
"""
|
||||
return json.dumps({
|
||||
'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
|
||||
})
|
||||
|
||||
def get_decrypted_token(self, token: str):
|
||||
"""
|
||||
Returns the decrypted token.
|
||||
"""
|
||||
config = json.loads(token)
|
||||
config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
|
||||
return config
|
||||
|
||||
def get_token_type(self):
|
||||
return dict
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
"""
|
||||
Validates the given config.
|
||||
"""
|
||||
# check OpenAI / Azure OpenAI credential is valid
|
||||
openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
|
||||
azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
|
||||
|
||||
provider = None
|
||||
if openai_provider:
|
||||
provider = openai_provider
|
||||
elif azure_openai_provider:
|
||||
provider = azure_openai_provider
|
||||
|
||||
if not provider:
|
||||
raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
quota_used = provider.quota_used if provider.quota_used is not None else 0
|
||||
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
|
||||
if quota_used >= quota_limit:
|
||||
raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
|
||||
f"please configure OpenAI or Azure OpenAI provider first.")
|
||||
|
||||
try:
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError('Config must be a object.')
|
||||
|
||||
if 'anthropic_api_key' not in config:
|
||||
raise ValueError('anthropic_api_key must be provided.')
|
||||
|
||||
chat_llm = ChatAnthropic(
|
||||
model='claude-instant-1',
|
||||
anthropic_api_key=config['anthropic_api_key'],
|
||||
max_tokens_to_sample=10,
|
||||
temperature=0,
|
||||
default_request_timeout=60
|
||||
)
|
||||
|
||||
messages = [
|
||||
HumanMessage(
|
||||
content="ping"
|
||||
)
|
||||
]
|
||||
|
||||
chat_llm(messages)
|
||||
except anthropic.APIConnectionError as ex:
|
||||
raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
|
||||
except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
|
||||
raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
|
||||
f"{ex.body['error']['type']}: {ex.body['error']['message']}")
|
||||
except Exception as ex:
|
||||
logging.exception('Anthropic config validation failed')
|
||||
raise ex
|
||||
|
||||
def get_hosted_credentials(self) -> Union[str | dict]:
|
||||
if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {self.get_provider_name().value} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}
|
|
@ -1,145 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.errors import ValidateFailedError
|
||||
from models.provider import ProviderName
|
||||
|
||||
|
||||
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
|
||||
|
||||
|
||||
class AzureProvider(BaseProvider):
|
||||
def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
|
||||
return []
|
||||
|
||||
def check_embedding_model(self, credentials: Optional[dict] = None):
|
||||
credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
|
||||
try:
|
||||
result = openai.Embedding.create(input=['test'],
|
||||
engine='text-embedding-ada-002',
|
||||
timeout=60,
|
||||
api_key=str(credentials.get('openai_api_key')),
|
||||
api_base=str(credentials.get('openai_api_base')),
|
||||
api_type='azure',
|
||||
api_version=str(credentials.get('openai_api_version')))["data"][0][
|
||||
"embedding"]
|
||||
except openai.error.AuthenticationError as e:
|
||||
raise AzureAuthenticationError(str(e))
|
||||
except openai.error.APIConnectionError as e:
|
||||
raise AzureRequestFailedError(
|
||||
'Failed to request Azure OpenAI, please check your API Base Endpoint, The format is `https://xxx.openai.azure.com/`')
|
||||
except openai.error.InvalidRequestError as e:
|
||||
if e.http_status == 404:
|
||||
raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
|
||||
"deployment name is exists in Azure AI")
|
||||
else:
|
||||
raise AzureRequestFailedError(
|
||||
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
|
||||
except openai.error.OpenAIError as e:
|
||||
raise AzureRequestFailedError(
|
||||
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
|
||||
|
||||
if not isinstance(result, list):
|
||||
raise AzureRequestFailedError('Failed to request Azure OpenAI.')
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Returns the API credentials for Azure OpenAI as a dictionary.
|
||||
"""
|
||||
config = self.get_provider_api_key(model_id=model_id)
|
||||
config['openai_api_type'] = 'azure'
|
||||
config['openai_api_version'] = AZURE_OPENAI_API_VERSION
|
||||
if model_id == 'text-embedding-ada-002':
|
||||
config['deployment'] = model_id.replace('.', '') if model_id else None
|
||||
config['chunk_size'] = 16
|
||||
else:
|
||||
config['deployment_name'] = model_id.replace('.', '') if model_id else None
|
||||
return config
|
||||
|
||||
def get_provider_name(self):
|
||||
return ProviderName.AZURE_OPENAI
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the provider configs.
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key(only_custom=only_custom)
|
||||
except:
|
||||
config = {
|
||||
'openai_api_type': 'azure',
|
||||
'openai_api_version': AZURE_OPENAI_API_VERSION,
|
||||
'openai_api_base': '',
|
||||
'openai_api_key': ''
|
||||
}
|
||||
|
||||
if obfuscated:
|
||||
if not config.get('openai_api_key'):
|
||||
config = {
|
||||
'openai_api_type': 'azure',
|
||||
'openai_api_version': AZURE_OPENAI_API_VERSION,
|
||||
'openai_api_base': '',
|
||||
'openai_api_key': ''
|
||||
}
|
||||
|
||||
config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
|
||||
return config
|
||||
|
||||
return config
|
||||
|
||||
def get_token_type(self):
|
||||
return dict
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
"""
|
||||
Validates the given config.
|
||||
"""
|
||||
try:
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError('Config must be a object.')
|
||||
|
||||
if 'openai_api_version' not in config:
|
||||
config['openai_api_version'] = AZURE_OPENAI_API_VERSION
|
||||
|
||||
self.check_embedding_model(credentials=config)
|
||||
except ValidateFailedError as e:
|
||||
raise e
|
||||
except AzureAuthenticationError:
|
||||
raise ValidateFailedError('Validation failed, please check your API Key.')
|
||||
except AzureRequestFailedError as ex:
|
||||
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
|
||||
except Exception as ex:
|
||||
logging.exception('Azure OpenAI Credentials validation failed')
|
||||
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
|
||||
|
||||
def get_encrypted_token(self, config: Union[dict | str]):
|
||||
"""
|
||||
Returns the encrypted token.
|
||||
"""
|
||||
return json.dumps({
|
||||
'openai_api_type': 'azure',
|
||||
'openai_api_version': AZURE_OPENAI_API_VERSION,
|
||||
'openai_api_base': config['openai_api_base'],
|
||||
'openai_api_key': self.encrypt_token(config['openai_api_key'])
|
||||
})
|
||||
|
||||
def get_decrypted_token(self, token: str):
|
||||
"""
|
||||
Returns the decrypted token.
|
||||
"""
|
||||
config = json.loads(token)
|
||||
config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
|
||||
return config
|
||||
|
||||
|
||||
class AzureAuthenticationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AzureRequestFailedError(Exception):
|
||||
pass
|
|
@ -1,132 +0,0 @@
|
|||
import base64
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.constant import llm_constant
|
||||
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from libs import rsa
|
||||
from models.account import Tenant
|
||||
from models.provider import Provider, ProviderType, ProviderName
|
||||
|
||||
|
||||
class BaseProvider(ABC):
|
||||
def __init__(self, tenant_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the decrypted API key for the given tenant_id and provider_name.
|
||||
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
|
||||
If the provider is not found or not valid, raises a ProviderTokenNotInitError.
|
||||
"""
|
||||
provider = self.get_provider(only_custom)
|
||||
if not provider:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {llm_constant.models[model_id]} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
quota_used = provider.quota_used if provider.quota_used is not None else 0
|
||||
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
|
||||
|
||||
if model_id and model_id == 'gpt-4':
|
||||
raise ModelCurrentlyNotSupportError()
|
||||
|
||||
if quota_used >= quota_limit:
|
||||
raise QuotaExceededError()
|
||||
|
||||
return self.get_hosted_credentials()
|
||||
else:
|
||||
return self.get_decrypted_token(provider.encrypted_config)
|
||||
|
||||
def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
|
||||
"""
|
||||
Returns the Provider instance for the given tenant_id and provider_name.
|
||||
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
|
||||
"""
|
||||
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
|
||||
|
||||
@classmethod
|
||||
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
|
||||
Provider]:
|
||||
"""
|
||||
Returns the Provider instance for the given tenant_id and provider_name.
|
||||
If both CUSTOM and System providers exist.
|
||||
"""
|
||||
query = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant_id
|
||||
)
|
||||
|
||||
if provider_name:
|
||||
query = query.filter(Provider.provider_name == provider_name)
|
||||
|
||||
if only_custom:
|
||||
query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
|
||||
|
||||
providers = query.order_by(Provider.provider_type.asc()).all()
|
||||
|
||||
for provider in providers:
|
||||
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
|
||||
return provider
|
||||
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
|
||||
return provider
|
||||
|
||||
return None
|
||||
|
||||
def get_hosted_credentials(self) -> Union[str | dict]:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {self.get_provider_name().value} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the provider configs.
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key(only_custom=only_custom)
|
||||
except:
|
||||
config = ''
|
||||
|
||||
if obfuscated:
|
||||
return self.obfuscated_token(config)
|
||||
|
||||
return config
|
||||
|
||||
def obfuscated_token(self, token: str):
|
||||
return token[:6] + '*' * (len(token) - 8) + token[-2:]
|
||||
|
||||
def get_token_type(self):
|
||||
return str
|
||||
|
||||
def get_encrypted_token(self, config: Union[dict | str]):
|
||||
return self.encrypt_token(config)
|
||||
|
||||
def get_decrypted_token(self, token: str):
|
||||
return self.decrypt_token(token)
|
||||
|
||||
def encrypt_token(self, token):
|
||||
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
|
||||
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
|
||||
return base64.b64encode(encrypted_token).decode()
|
||||
|
||||
def decrypt_token(self, token):
|
||||
return rsa.decrypt(base64.b64decode(token), self.tenant_id)
|
||||
|
||||
@abstractmethod
|
||||
def get_provider_name(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def config_validate(self, config: str):
|
||||
raise NotImplementedError
|
|
@ -1,2 +0,0 @@
|
|||
class ValidateFailedError(Exception):
|
||||
description = "Provider Validate failed"
|
|
@ -1,22 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from models.provider import ProviderName
|
||||
|
||||
|
||||
class HuggingfaceProvider(BaseProvider):
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
credentials = self.get_credentials(model_id)
|
||||
# todo
|
||||
return []
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Returns the API credentials for Huggingface as a dictionary, for the given tenant_id.
|
||||
"""
|
||||
return {
|
||||
'huggingface_api_key': self.get_provider_api_key(model_id=model_id)
|
||||
}
|
||||
|
||||
def get_provider_name(self):
|
||||
return ProviderName.HUGGINGFACEHUB
|
|
@ -1,53 +0,0 @@
|
|||
from typing import Optional, Union
|
||||
|
||||
from core.llm.provider.anthropic_provider import AnthropicProvider
|
||||
from core.llm.provider.azure_provider import AzureProvider
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.huggingface_provider import HuggingfaceProvider
|
||||
from core.llm.provider.openai_provider import OpenAIProvider
|
||||
from models.provider import Provider
|
||||
|
||||
|
||||
class LLMProviderService:
|
||||
|
||||
def __init__(self, tenant_id: str, provider_name: str):
|
||||
self.provider = self.init_provider(tenant_id, provider_name)
|
||||
|
||||
def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider:
|
||||
if provider_name == 'openai':
|
||||
return OpenAIProvider(tenant_id)
|
||||
elif provider_name == 'azure_openai':
|
||||
return AzureProvider(tenant_id)
|
||||
elif provider_name == 'anthropic':
|
||||
return AnthropicProvider(tenant_id)
|
||||
elif provider_name == 'huggingface':
|
||||
return HuggingfaceProvider(tenant_id)
|
||||
else:
|
||||
raise Exception('provider {} not found'.format(provider_name))
|
||||
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
return self.provider.get_models(model_id)
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
return self.provider.get_credentials(model_id)
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
|
||||
|
||||
def get_provider_db_record(self) -> Optional[Provider]:
|
||||
return self.provider.get_provider()
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
"""
|
||||
Validates the given config.
|
||||
|
||||
:param config:
|
||||
:raises: ValidateFailedError
|
||||
"""
|
||||
return self.provider.config_validate(config)
|
||||
|
||||
def get_token_type(self):
|
||||
return self.provider.get_token_type()
|
||||
|
||||
def get_encrypted_token(self, config: Union[dict | str]):
|
||||
return self.provider.get_encrypted_token(config)
|
|
@ -1,55 +0,0 @@
|
|||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import openai
|
||||
from openai.error import AuthenticationError, OpenAIError
|
||||
|
||||
from core import hosted_llm_credentials
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
from core.llm.moderation import Moderation
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.errors import ValidateFailedError
|
||||
from models.provider import ProviderName
|
||||
|
||||
|
||||
class OpenAIProvider(BaseProvider):
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
credentials = self.get_credentials(model_id)
|
||||
response = openai.Model.list(**credentials)
|
||||
|
||||
return [{
|
||||
'id': model['id'],
|
||||
'name': model['id'],
|
||||
} for model in response['data']]
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Returns the credentials for the given tenant_id and provider_name.
|
||||
"""
|
||||
return {
|
||||
'openai_api_key': self.get_provider_api_key(model_id=model_id)
|
||||
}
|
||||
|
||||
def get_provider_name(self):
|
||||
return ProviderName.OPENAI
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
"""
|
||||
Validates the given config.
|
||||
"""
|
||||
try:
|
||||
Moderation(self.get_provider_name().value, config).moderate('test')
|
||||
except (AuthenticationError, OpenAIError) as ex:
|
||||
raise ValidateFailedError(str(ex))
|
||||
except Exception as ex:
|
||||
logging.exception('OpenAI config validation failed')
|
||||
raise ex
|
||||
|
||||
def get_hosted_credentials(self) -> Union[str | dict]:
|
||||
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {self.get_provider_name().value} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
return hosted_llm_credentials.openai.api_key
|
|
@ -1,62 +0,0 @@
|
|||
from typing import List, Optional, Any, Dict
|
||||
|
||||
from httpx import Timeout
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
|
||||
|
||||
|
||||
class StreamableChatAnthropic(ChatAnthropic):
|
||||
"""
|
||||
Wrapper around Anthropic's large language model.
|
||||
"""
|
||||
|
||||
default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0)
|
||||
|
||||
@root_validator()
|
||||
def prepare_params(cls, values: Dict) -> Dict:
|
||||
values['model_name'] = values.get('model')
|
||||
values['max_tokens'] = values.get('max_tokens_to_sample')
|
||||
return values
|
||||
|
||||
@handle_anthropic_exceptions
|
||||
def generate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
params['model'] = params.get('model_name')
|
||||
del params['model_name']
|
||||
|
||||
params['max_tokens_to_sample'] = params.get('max_tokens')
|
||||
del params['max_tokens']
|
||||
|
||||
del params['frequency_penalty']
|
||||
del params['presence_penalty']
|
||||
|
||||
return params
|
||||
|
||||
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_text = f"{self.HUMAN_PROMPT} {message.content}"
|
||||
elif isinstance(message, AIMessage):
|
||||
message_text = f"{self.AI_PROMPT} {message.content}"
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_text = f"<admin>{message.content}</admin>"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_text
|
|
@ -1,41 +0,0 @@
|
|||
import decimal
|
||||
from typing import Optional
|
||||
|
||||
import tiktoken
|
||||
|
||||
from core.constant import llm_constant
|
||||
|
||||
|
||||
class TokenCalculator:
|
||||
@classmethod
|
||||
def get_num_tokens(cls, model_name: str, text: str):
|
||||
if len(text) == 0:
|
||||
return 0
|
||||
|
||||
enc = tiktoken.encoding_for_model(model_name)
|
||||
|
||||
tokenized_text = enc.encode(text)
|
||||
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
|
||||
@classmethod
|
||||
def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal:
|
||||
if model_name in llm_constant.models_by_mode['embedding']:
|
||||
unit_price = llm_constant.model_prices[model_name]['usage']
|
||||
elif text_type == 'prompt':
|
||||
unit_price = llm_constant.model_prices[model_name]['prompt']
|
||||
elif text_type == 'completion':
|
||||
unit_price = llm_constant.model_prices[model_name]['completion']
|
||||
else:
|
||||
raise Exception('Invalid text type')
|
||||
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
@classmethod
|
||||
def get_currency(cls, model_name: str):
|
||||
return llm_constant.model_currency
|
|
@ -1,26 +0,0 @@
|
|||
import openai
|
||||
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
from models.provider import ProviderName
|
||||
from core.llm.provider.base import BaseProvider
|
||||
|
||||
|
||||
class Whisper:
|
||||
|
||||
def __init__(self, provider: BaseProvider):
|
||||
self.provider = provider
|
||||
|
||||
if self.provider.get_provider_name() == ProviderName.OPENAI:
|
||||
self.client = openai.Audio
|
||||
self.credentials = provider.get_credentials()
|
||||
|
||||
@handle_openai_exceptions
|
||||
def transcribe(self, file):
|
||||
return self.client.transcribe(
|
||||
model='whisper-1',
|
||||
file=file,
|
||||
api_key=self.credentials.get('openai_api_key'),
|
||||
api_base=self.credentials.get('openai_api_base'),
|
||||
api_type=self.credentials.get('openai_api_type'),
|
||||
api_version=self.credentials.get('openai_api_version'),
|
||||
)
|
|
@ -1,27 +0,0 @@
|
|||
import logging
|
||||
from functools import wraps
|
||||
|
||||
import anthropic
|
||||
|
||||
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
|
||||
LLMBadRequestError
|
||||
|
||||
|
||||
def handle_anthropic_exceptions(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except anthropic.APIConnectionError as e:
|
||||
logging.exception("Failed to connect to Anthropic API.")
|
||||
raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
|
||||
except anthropic.RateLimitError:
|
||||
raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
|
||||
except anthropic.AuthenticationError as e:
|
||||
raise LLMAuthorizationError(f"Anthropic: {e.message}")
|
||||
except anthropic.BadRequestError as e:
|
||||
raise LLMBadRequestError(f"Anthropic: {e.message}")
|
||||
except anthropic.APIStatusError as e:
|
||||
raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
|
||||
|
||||
return wrapper
|
|
@ -1,31 +0,0 @@
|
|||
import logging
|
||||
from functools import wraps
|
||||
|
||||
import openai
|
||||
|
||||
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
|
||||
LLMBadRequestError
|
||||
|
||||
|
||||
def handle_openai_exceptions(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except openai.error.InvalidRequestError as e:
|
||||
logging.exception("Invalid request to OpenAI API.")
|
||||
raise LLMBadRequestError(str(e))
|
||||
except openai.error.APIConnectionError as e:
|
||||
logging.exception("Failed to connect to OpenAI API.")
|
||||
raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
|
||||
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
|
||||
logging.exception("OpenAI service unavailable.")
|
||||
raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
|
||||
except openai.error.RateLimitError as e:
|
||||
raise LLMRateLimitError(str(e))
|
||||
except openai.error.AuthenticationError as e:
|
||||
raise LLMAuthorizationError(str(e))
|
||||
except openai.error.OpenAIError as e:
|
||||
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
|
||||
|
||||
return wrapper
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Any, List, Dict, Union
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
|
||||
from langchain.schema import get_buffer_string, BaseMessage
|
||||
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message
|
||||
|
||||
|
@ -13,7 +13,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
|
|||
conversation: Conversation
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "Assistant"
|
||||
llm: BaseLanguageModel
|
||||
model_instance: BaseLLM
|
||||
memory_key: str = "chat_history"
|
||||
max_token_limit: int = 2000
|
||||
message_limit: int = 10
|
||||
|
@ -29,23 +29,23 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
|
|||
|
||||
messages = list(reversed(messages))
|
||||
|
||||
chat_messages: List[BaseMessage] = []
|
||||
chat_messages: List[PromptMessage] = []
|
||||
for message in messages:
|
||||
chat_messages.append(HumanMessage(content=message.query))
|
||||
chat_messages.append(AIMessage(content=message.answer))
|
||||
chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
|
||||
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
|
||||
|
||||
if not chat_messages:
|
||||
return chat_messages
|
||||
return []
|
||||
|
||||
# prune the chat message if it exceeds the max token limit
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
|
||||
curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_buffer_length > self.max_token_limit and chat_messages:
|
||||
pruned_memory.append(chat_messages.pop(0))
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
|
||||
curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
|
||||
|
||||
return chat_messages
|
||||
return to_lc_messages(chat_messages)
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
|
|
293
api/core/model_providers/model_factory.py
Normal file
293
api/core/model_providers/model_factory.py
Normal file
|
@ -0,0 +1,293 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain.callbacks.base import Callbacks
|
||||
|
||||
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.speech2text.base import BaseSpeech2Text
|
||||
from extensions.ext_database import db
|
||||
from models.provider import TenantDefaultModel
|
||||
|
||||
|
||||
class ModelFactory:
|
||||
|
||||
@classmethod
|
||||
def get_text_generation_model_from_model_config(cls, tenant_id: str,
|
||||
model_config: dict,
|
||||
streaming: bool = False,
|
||||
callbacks: Callbacks = None) -> Optional[BaseLLM]:
|
||||
provider_name = model_config.get("provider")
|
||||
model_name = model_config.get("name")
|
||||
completion_params = model_config.get("completion_params", {})
|
||||
|
||||
return cls.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
model_kwargs=ModelKwargs(
|
||||
temperature=completion_params.get('temperature', 0),
|
||||
max_tokens=completion_params.get('max_tokens', 256),
|
||||
top_p=completion_params.get('top_p', 0),
|
||||
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
|
||||
presence_penalty=completion_params.get('presence_penalty', 0.1)
|
||||
),
|
||||
streaming=streaming,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_text_generation_model(cls,
|
||||
tenant_id: str,
|
||||
model_provider_name: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
model_kwargs: Optional[ModelKwargs] = None,
|
||||
streaming: bool = False,
|
||||
callbacks: Callbacks = None) -> Optional[BaseLLM]:
|
||||
"""
|
||||
get text generation model.
|
||||
|
||||
:param tenant_id: a string representing the ID of the tenant.
|
||||
:param model_provider_name:
|
||||
:param model_name:
|
||||
:param model_kwargs:
|
||||
:param streaming:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
is_default_model = False
|
||||
if model_provider_name is None and model_name is None:
|
||||
default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)
|
||||
|
||||
if not default_model:
|
||||
raise LLMBadRequestError(f"Default model is not available. "
|
||||
f"Please configure a Default System Reasoning Model "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
model_provider_name = default_model.provider_name
|
||||
model_name = default_model.model_name
|
||||
is_default_model = True
|
||||
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||
|
||||
if not model_provider:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
|
||||
|
||||
# init text generation model
|
||||
model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)
|
||||
|
||||
try:
|
||||
model_instance = model_class(
|
||||
model_provider=model_provider,
|
||||
name=model_name,
|
||||
model_kwargs=model_kwargs,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks
|
||||
)
|
||||
except LLMBadRequestError as e:
|
||||
if is_default_model:
|
||||
raise LLMBadRequestError(f"Default model {model_name} is not available. "
|
||||
f"Please check your model provider credentials.")
|
||||
else:
|
||||
raise e
|
||||
|
||||
if is_default_model:
|
||||
model_instance.deduct_quota = False
|
||||
|
||||
return model_instance
|
||||
|
||||
@classmethod
|
||||
def get_embedding_model(cls,
|
||||
tenant_id: str,
|
||||
model_provider_name: Optional[str] = None,
|
||||
model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
|
||||
"""
|
||||
get embedding model.
|
||||
|
||||
:param tenant_id: a string representing the ID of the tenant.
|
||||
:param model_provider_name:
|
||||
:param model_name:
|
||||
:return:
|
||||
"""
|
||||
if model_provider_name is None and model_name is None:
|
||||
default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)
|
||||
|
||||
if not default_model:
|
||||
raise LLMBadRequestError(f"Default model is not available. "
|
||||
f"Please configure a Default Embedding Model "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
model_provider_name = default_model.provider_name
|
||||
model_name = default_model.model_name
|
||||
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||
|
||||
if not model_provider:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
|
||||
|
||||
# init embedding model
|
||||
model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
|
||||
return model_class(
|
||||
model_provider=model_provider,
|
||||
name=model_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_speech2text_model(cls,
|
||||
tenant_id: str,
|
||||
model_provider_name: Optional[str] = None,
|
||||
model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
|
||||
"""
|
||||
get speech to text model.
|
||||
|
||||
:param tenant_id: a string representing the ID of the tenant.
|
||||
:param model_provider_name:
|
||||
:param model_name:
|
||||
:return:
|
||||
"""
|
||||
if model_provider_name is None and model_name is None:
|
||||
default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)
|
||||
|
||||
if not default_model:
|
||||
raise LLMBadRequestError(f"Default model is not available. "
|
||||
f"Please configure a Default Speech-to-Text Model "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
model_provider_name = default_model.provider_name
|
||||
model_name = default_model.model_name
|
||||
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||
|
||||
if not model_provider:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
|
||||
|
||||
# init speech to text model
|
||||
model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
|
||||
return model_class(
|
||||
model_provider=model_provider,
|
||||
name=model_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_moderation_model(cls,
|
||||
tenant_id: str,
|
||||
model_provider_name: str,
|
||||
model_name: str) -> Optional[BaseProviderModel]:
|
||||
"""
|
||||
get moderation model.
|
||||
|
||||
:param tenant_id: a string representing the ID of the tenant.
|
||||
:param model_provider_name:
|
||||
:param model_name:
|
||||
:return:
|
||||
"""
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||
|
||||
if not model_provider:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
|
||||
|
||||
# init moderation model
|
||||
model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
|
||||
return model_class(
|
||||
model_provider=model_provider,
|
||||
name=model_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
|
||||
"""
|
||||
get default model of model type.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
# get default model
|
||||
default_model = db.session.query(TenantDefaultModel) \
|
||||
.filter(
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.value
|
||||
).first()
|
||||
|
||||
if not default_model:
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rules()
|
||||
for model_provider_name, model_provider_rule in model_provider_rules.items():
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||
if not model_provider:
|
||||
continue
|
||||
|
||||
model_list = model_provider.get_supported_model_list(model_type)
|
||||
if model_list:
|
||||
model_info = model_list[0]
|
||||
default_model = TenantDefaultModel(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type.value,
|
||||
provider_name=model_provider_name,
|
||||
model_name=model_info['id']
|
||||
)
|
||||
db.session.add(default_model)
|
||||
db.session.commit()
|
||||
break
|
||||
|
||||
return default_model
|
||||
|
||||
@classmethod
|
||||
def update_default_model(cls,
|
||||
tenant_id: str,
|
||||
model_type: ModelType,
|
||||
provider_name: str,
|
||||
model_name: str) -> TenantDefaultModel:
|
||||
"""
|
||||
update default model of model type.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_type:
|
||||
:param provider_name:
|
||||
:param model_name:
|
||||
:return:
|
||||
"""
|
||||
model_provider_name = ModelProviderFactory.get_provider_names()
|
||||
if provider_name not in model_provider_name:
|
||||
raise ValueError(f'Invalid provider name: {provider_name}')
|
||||
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)
|
||||
|
||||
if not model_provider:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
|
||||
|
||||
model_list = model_provider.get_supported_model_list(model_type)
|
||||
model_ids = [model['id'] for model in model_list]
|
||||
if model_name not in model_ids:
|
||||
raise ValueError(f'Invalid model name: {model_name}')
|
||||
|
||||
# get default model
|
||||
default_model = db.session.query(TenantDefaultModel) \
|
||||
.filter(
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.value
|
||||
).first()
|
||||
|
||||
if default_model:
|
||||
# update default model
|
||||
default_model.provider_name = provider_name
|
||||
default_model.model_name = model_name
|
||||
db.session.commit()
|
||||
else:
|
||||
# create default model
|
||||
default_model = TenantDefaultModel(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type.value,
|
||||
provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
)
|
||||
db.session.add(default_model)
|
||||
db.session.commit()
|
||||
|
||||
return default_model
|
228
api/core/model_providers/model_provider_factory.py
Normal file
228
api/core/model_providers/model_provider_factory.py
Normal file
|
@ -0,0 +1,228 @@
|
|||
from typing import Type
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.rules import provider_rules
|
||||
from extensions.ext_database import db
|
||||
from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType
|
||||
|
||||
DEFAULT_MODELS = {
|
||||
ModelType.TEXT_GENERATION.value: {
|
||||
'provider_name': 'openai',
|
||||
'model_name': 'gpt-3.5-turbo',
|
||||
},
|
||||
ModelType.EMBEDDINGS.value: {
|
||||
'provider_name': 'openai',
|
||||
'model_name': 'text-embedding-ada-002',
|
||||
},
|
||||
ModelType.SPEECH_TO_TEXT.value: {
|
||||
'provider_name': 'openai',
|
||||
'model_name': 'whisper-1',
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ModelProviderFactory:
|
||||
@classmethod
|
||||
def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
|
||||
if provider_name == 'openai':
|
||||
from core.model_providers.providers.openai_provider import OpenAIProvider
|
||||
return OpenAIProvider
|
||||
elif provider_name == 'anthropic':
|
||||
from core.model_providers.providers.anthropic_provider import AnthropicProvider
|
||||
return AnthropicProvider
|
||||
elif provider_name == 'minimax':
|
||||
from core.model_providers.providers.minimax_provider import MinimaxProvider
|
||||
return MinimaxProvider
|
||||
elif provider_name == 'spark':
|
||||
from core.model_providers.providers.spark_provider import SparkProvider
|
||||
return SparkProvider
|
||||
elif provider_name == 'tongyi':
|
||||
from core.model_providers.providers.tongyi_provider import TongyiProvider
|
||||
return TongyiProvider
|
||||
elif provider_name == 'wenxin':
|
||||
from core.model_providers.providers.wenxin_provider import WenxinProvider
|
||||
return WenxinProvider
|
||||
elif provider_name == 'chatglm':
|
||||
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
|
||||
return ChatGLMProvider
|
||||
elif provider_name == 'azure_openai':
|
||||
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
return AzureOpenAIProvider
|
||||
elif provider_name == 'replicate':
|
||||
from core.model_providers.providers.replicate_provider import ReplicateProvider
|
||||
return ReplicateProvider
|
||||
elif provider_name == 'huggingface_hub':
|
||||
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
|
||||
return HuggingfaceHubProvider
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_provider_names(cls):
|
||||
"""
|
||||
Returns a list of provider names.
|
||||
"""
|
||||
return list(provider_rules.keys())
|
||||
|
||||
@classmethod
|
||||
def get_provider_rules(cls):
|
||||
"""
|
||||
Returns a list of provider rules.
|
||||
|
||||
:return:
|
||||
"""
|
||||
return provider_rules
|
||||
|
||||
@classmethod
|
||||
def get_provider_rule(cls, provider_name: str):
|
||||
"""
|
||||
Returns provider rule.
|
||||
"""
|
||||
return provider_rules[provider_name]
|
||||
|
||||
@classmethod
|
||||
def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
|
||||
"""
|
||||
get preferred model provider.
|
||||
|
||||
:param tenant_id: a string representing the ID of the tenant.
|
||||
:param model_provider_name:
|
||||
:return:
|
||||
"""
|
||||
# get preferred provider
|
||||
preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
|
||||
if not preferred_provider or not preferred_provider.is_valid:
|
||||
return None
|
||||
|
||||
# init model provider
|
||||
model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
|
||||
return model_provider_class(provider=preferred_provider)
|
||||
|
||||
@classmethod
|
||||
def get_preferred_type_by_preferred_model_provider(cls,
|
||||
tenant_id: str,
|
||||
model_provider_name: str,
|
||||
preferred_model_provider: TenantPreferredModelProvider):
|
||||
"""
|
||||
get preferred provider type by preferred model provider.
|
||||
|
||||
:param model_provider_name:
|
||||
:param preferred_model_provider:
|
||||
:return:
|
||||
"""
|
||||
if not preferred_model_provider:
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
|
||||
support_provider_types = model_provider_rules['support_provider_types']
|
||||
|
||||
if ProviderType.CUSTOM.value in support_provider_types:
|
||||
custom_provider = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == model_provider_name,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.is_valid == True
|
||||
).first()
|
||||
|
||||
if custom_provider:
|
||||
return ProviderType.CUSTOM.value
|
||||
|
||||
model_provider = cls.get_model_provider_class(model_provider_name)
|
||||
|
||||
if ProviderType.SYSTEM.value in support_provider_types \
|
||||
and model_provider.is_provider_type_system_supported():
|
||||
return ProviderType.SYSTEM.value
|
||||
elif ProviderType.CUSTOM.value in support_provider_types:
|
||||
return ProviderType.CUSTOM.value
|
||||
else:
|
||||
return preferred_model_provider.preferred_provider_type
|
||||
|
||||
@classmethod
|
||||
def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
|
||||
"""
|
||||
get preferred provider of tenant.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_provider_name:
|
||||
:return:
|
||||
"""
|
||||
# get preferred provider type
|
||||
preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)
|
||||
|
||||
# get providers by preferred provider type
|
||||
providers = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == model_provider_name,
|
||||
Provider.provider_type == preferred_provider_type
|
||||
).all()
|
||||
|
||||
no_system_provider = False
|
||||
if preferred_provider_type == ProviderType.SYSTEM.value:
|
||||
quota_type_to_provider_dict = {}
|
||||
for provider in providers:
|
||||
quota_type_to_provider_dict[provider.quota_type] = provider
|
||||
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
|
||||
for quota_type_enum in ProviderQuotaType:
|
||||
quota_type = quota_type_enum.value
|
||||
if quota_type in model_provider_rules['system_config']['supported_quota_types'] \
|
||||
and quota_type in quota_type_to_provider_dict.keys():
|
||||
provider = quota_type_to_provider_dict[quota_type]
|
||||
if provider.is_valid and provider.quota_limit > provider.quota_used:
|
||||
return provider
|
||||
|
||||
no_system_provider = True
|
||||
|
||||
if no_system_provider:
|
||||
providers = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == model_provider_name,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).all()
|
||||
|
||||
if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
|
||||
if providers:
|
||||
return providers[0]
|
||||
else:
|
||||
try:
|
||||
provider = Provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=model_provider_name,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
is_valid=False
|
||||
)
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
provider = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == model_provider_name,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).first()
|
||||
|
||||
return provider
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
|
||||
"""
|
||||
get preferred provider type of tenant.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_provider_name:
|
||||
:return:
|
||||
"""
|
||||
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
|
||||
.filter(
|
||||
TenantPreferredModelProvider.tenant_id == tenant_id,
|
||||
TenantPreferredModelProvider.provider_name == model_provider_name
|
||||
).first()
|
||||
|
||||
return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)
|
22
api/core/model_providers/models/base.py
Normal file
22
api/core/model_providers/models/base.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
from abc import ABC
|
||||
from typing import Any
|
||||
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
|
||||
class BaseProviderModel(ABC):
|
||||
_client: Any
|
||||
_model_provider: BaseModelProvider
|
||||
|
||||
def __init__(self, model_provider: BaseModelProvider, client: Any):
|
||||
self._model_provider = model_provider
|
||||
self._client = client
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def model_provider(self):
|
||||
return self._model_provider
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
import decimal
|
||||
import logging
|
||||
|
||||
import openai
|
||||
import tiktoken
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMRateLimitError, \
|
||||
LLMAPIUnavailableError, LLMAPIConnectionError
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
|
||||
|
||||
|
||||
class AzureOpenAIEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
self.credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = OpenAIEmbeddings(
|
||||
deployment=name,
|
||||
openai_api_type='azure',
|
||||
openai_api_version=AZURE_OPENAI_API_VERSION,
|
||||
chunk_size=16,
|
||||
max_retries=1,
|
||||
**self.credentials
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""
|
||||
get num tokens of text.
|
||||
|
||||
:param text:
|
||||
:return:
|
||||
"""
|
||||
if len(text) == 0:
|
||||
return 0
|
||||
|
||||
enc = tiktoken.encoding_for_model(self.credentials.get('base_model_name'))
|
||||
|
||||
tokenized_text = enc.encode(text)
|
||||
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * decimal.Decimal('0.0001')
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to Azure OpenAI API.")
|
||||
return LLMBadRequestError(str(ex))
|
||||
elif isinstance(ex, openai.error.APIConnectionError):
|
||||
logging.warning("Failed to connect to Azure OpenAI API.")
|
||||
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
|
||||
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
|
||||
logging.warning("Azure OpenAI service unavailable.")
|
||||
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError('Azure ' + str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError('Azure ' + str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
return ex
|
40
api/core/model_providers/models/embedding/base.py
Normal file
40
api/core/model_providers/models/embedding/base.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
from langchain.schema.language_model import _get_token_ids_default_method
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
|
||||
class BaseEmbedding(BaseProviderModel):
|
||||
name: str
|
||||
type: ModelType = ModelType.EMBEDDINGS
|
||||
|
||||
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
|
||||
super().__init__(model_provider, client)
|
||||
self.name = name
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""
|
||||
get num tokens of text.
|
||||
|
||||
:param text:
|
||||
:return:
|
||||
"""
|
||||
if len(text) == 0:
|
||||
return 0
|
||||
|
||||
return len(_get_token_ids_default_method(text))
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
return 0
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
@abstractmethod
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
raise NotImplementedError
|
|
@ -0,0 +1,35 @@
|
|||
import decimal
|
||||
import logging
|
||||
|
||||
from langchain.embeddings import MiniMaxEmbeddings
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
|
||||
class MinimaxEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = MiniMaxEmbeddings(
|
||||
model=name,
|
||||
**credentials
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, ValueError):
|
||||
return LLMBadRequestError(f"Minimax: {str(ex)}")
|
||||
else:
|
||||
return ex
|
|
@ -0,0 +1,72 @@
|
|||
import decimal
|
||||
import logging
|
||||
|
||||
import openai
|
||||
import tiktoken
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||
LLMRateLimitError, LLMAuthorizationError
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
|
||||
class OpenAIEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = OpenAIEmbeddings(
|
||||
max_retries=1,
|
||||
**credentials
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""
|
||||
get num tokens of text.
|
||||
|
||||
:param text:
|
||||
:return:
|
||||
"""
|
||||
if len(text) == 0:
|
||||
return 0
|
||||
|
||||
enc = tiktoken.encoding_for_model(self.name)
|
||||
|
||||
tokenized_text = enc.encode(text)
|
||||
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * decimal.Decimal('0.0001')
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to OpenAI API.")
|
||||
return LLMBadRequestError(str(ex))
|
||||
elif isinstance(ex, openai.error.APIConnectionError):
|
||||
logging.warning("Failed to connect to OpenAI API.")
|
||||
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
|
||||
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
|
||||
logging.warning("OpenAI service unavailable.")
|
||||
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError(str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError(str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
return ex
|
|
@ -0,0 +1,36 @@
|
|||
import decimal
|
||||
|
||||
from replicate.exceptions import ModelError, ReplicateError
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.embeddings.replicate_embedding import ReplicateEmbeddings
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
|
||||
|
||||
class ReplicateEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = ReplicateEmbeddings(
|
||||
model=name + ':' + credentials.get('model_version'),
|
||||
replicate_api_token=credentials.get('replicate_api_token')
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
# replicate only pay for prediction seconds
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, (ModelError, ReplicateError)):
|
||||
return LLMBadRequestError(f"Replicate: {str(ex)}")
|
||||
else:
|
||||
return ex
|
53
api/core/model_providers/models/entity/message.py
Normal file
53
api/core/model_providers/models/entity/message.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
import enum
|
||||
|
||||
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMRunResult(BaseModel):
|
||||
content: str
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
|
||||
|
||||
class MessageType(enum.Enum):
|
||||
HUMAN = 'human'
|
||||
ASSISTANT = 'assistant'
|
||||
SYSTEM = 'system'
|
||||
|
||||
|
||||
class PromptMessage(BaseModel):
|
||||
type: MessageType = MessageType.HUMAN
|
||||
content: str = ''
|
||||
|
||||
|
||||
def to_lc_messages(messages: list[PromptMessage]):
|
||||
lc_messages = []
|
||||
for message in messages:
|
||||
if message.type == MessageType.HUMAN:
|
||||
lc_messages.append(HumanMessage(content=message.content))
|
||||
elif message.type == MessageType.ASSISTANT:
|
||||
lc_messages.append(AIMessage(content=message.content))
|
||||
elif message.type == MessageType.SYSTEM:
|
||||
lc_messages.append(SystemMessage(content=message.content))
|
||||
|
||||
return lc_messages
|
||||
|
||||
|
||||
def to_prompt_messages(messages: list[BaseMessage]):
|
||||
prompt_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, HumanMessage):
|
||||
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
|
||||
elif isinstance(message, AIMessage):
|
||||
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
|
||||
elif isinstance(message, SystemMessage):
|
||||
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
|
||||
return prompt_messages
|
||||
|
||||
|
||||
def str_to_prompt_messages(texts: list[str]):
|
||||
prompt_messages = []
|
||||
for text in texts:
|
||||
prompt_messages.append(PromptMessage(content=text))
|
||||
return prompt_messages
|
59
api/core/model_providers/models/entity/model_params.py
Normal file
59
api/core/model_providers/models/entity/model_params.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
import enum
|
||||
from typing import Optional, TypeVar, Generic
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ModelMode(enum.Enum):
|
||||
COMPLETION = 'completion'
|
||||
CHAT = 'chat'
|
||||
|
||||
|
||||
class ModelType(enum.Enum):
|
||||
TEXT_GENERATION = 'text-generation'
|
||||
EMBEDDINGS = 'embeddings'
|
||||
SPEECH_TO_TEXT = 'speech2text'
|
||||
IMAGE = 'image'
|
||||
VIDEO = 'video'
|
||||
MODERATION = 'moderation'
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in ModelType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class ModelKwargs(BaseModel):
|
||||
max_tokens: Optional[int]
|
||||
temperature: Optional[float]
|
||||
top_p: Optional[float]
|
||||
presence_penalty: Optional[float]
|
||||
frequency_penalty: Optional[float]
|
||||
|
||||
|
||||
class KwargRuleType(enum.Enum):
|
||||
STRING = 'string'
|
||||
INTEGER = 'integer'
|
||||
FLOAT = 'float'
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class KwargRule(Generic[T], BaseModel):
|
||||
enabled: bool = True
|
||||
min: Optional[T] = None
|
||||
max: Optional[T] = None
|
||||
default: Optional[T] = None
|
||||
alias: Optional[str] = None
|
||||
|
||||
|
||||
class ModelKwargsRules(BaseModel):
|
||||
max_tokens: KwargRule = KwargRule[int](enabled=False)
|
||||
temperature: KwargRule = KwargRule[float](enabled=False)
|
||||
top_p: KwargRule = KwargRule[float](enabled=False)
|
||||
presence_penalty: KwargRule = KwargRule[float](enabled=False)
|
||||
frequency_penalty: KwargRule = KwargRule[float](enabled=False)
|
10
api/core/model_providers/models/entity/provider.py
Normal file
10
api/core/model_providers/models/entity/provider.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class ProviderQuotaUnit(Enum):
|
||||
TIMES = 'times'
|
||||
TOKENS = 'tokens'
|
||||
|
||||
|
||||
class ModelFeature(Enum):
|
||||
AGENT_THOUGHT = 'agent_thought'
|
0
api/core/model_providers/models/llm/__init__.py
Normal file
0
api/core/model_providers/models/llm/__init__.py
Normal file
107
api/core/model_providers/models/llm/anthropic_model.py
Normal file
107
api/core/model_providers/models/llm/anthropic_model.py
Normal file
|
@ -0,0 +1,107 @@
|
|||
import decimal
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Any
|
||||
|
||||
import anthropic
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||
LLMRateLimitError, LLMAuthorizationError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
|
||||
|
||||
class AnthropicModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.CHAT
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return ChatAnthropic(
|
||||
model=self.name,
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
default_request_timeout=60,
|
||||
**self.credentials,
|
||||
**provider_model_kwargs
|
||||
)
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'claude-instant-1': {
|
||||
'prompt': decimal.Decimal('1.63'),
|
||||
'completion': decimal.Decimal('5.51'),
|
||||
},
|
||||
'claude-2': {
|
||||
'prompt': decimal.Decimal('11.02'),
|
||||
'completion': decimal.Decimal('32.68'),
|
||||
},
|
||||
}
|
||||
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[self.name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[self.name]['completion']
|
||||
|
||||
tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1m * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
if hasattr(self.client, k):
|
||||
setattr(self.client, k, v)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, anthropic.APIConnectionError):
|
||||
logging.warning("Failed to connect to Anthropic API.")
|
||||
return LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {ex.__cause__}")
|
||||
elif isinstance(ex, anthropic.RateLimitError):
|
||||
return LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
|
||||
elif isinstance(ex, anthropic.AuthenticationError):
|
||||
return LLMAuthorizationError(f"Anthropic: {ex.message}")
|
||||
elif isinstance(ex, anthropic.BadRequestError):
|
||||
return LLMBadRequestError(f"Anthropic: {ex.message}")
|
||||
elif isinstance(ex, anthropic.APIStatusError):
|
||||
return LLMAPIUnavailableError(f"Anthropic: code: {ex.status_code}, cause: {ex.message}")
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return True
|
||||
|
177
api/core/model_providers/models/llm/azure_openai_model.py
Normal file
177
api/core/model_providers/models/llm/azure_openai_model.py
Normal file
|
@ -0,0 +1,177 @@
|
|||
import decimal
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Any
|
||||
|
||||
import openai
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI
|
||||
from core.third_party.langchain.llms.azure_open_ai import EnhanceAzureOpenAI
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||
LLMRateLimitError, LLMAuthorizationError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
|
||||
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
|
||||
|
||||
|
||||
class AzureOpenAIModel(BaseLLM):
|
||||
def __init__(self, model_provider: BaseModelProvider,
|
||||
name: str,
|
||||
model_kwargs: ModelKwargs,
|
||||
streaming: bool = False,
|
||||
callbacks: Callbacks = None):
|
||||
if name == 'text-davinci-003':
|
||||
self.model_mode = ModelMode.COMPLETION
|
||||
else:
|
||||
self.model_mode = ModelMode.CHAT
|
||||
|
||||
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
if self.name == 'text-davinci-003':
|
||||
client = EnhanceAzureOpenAI(
|
||||
deployment_name=self.name,
|
||||
streaming=self.streaming,
|
||||
request_timeout=60,
|
||||
openai_api_type='azure',
|
||||
openai_api_version=AZURE_OPENAI_API_VERSION,
|
||||
openai_api_key=self.credentials.get('openai_api_key'),
|
||||
openai_api_base=self.credentials.get('openai_api_base'),
|
||||
callbacks=self.callbacks,
|
||||
**provider_model_kwargs
|
||||
)
|
||||
else:
|
||||
extra_model_kwargs = {
|
||||
'top_p': provider_model_kwargs.get('top_p'),
|
||||
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
|
||||
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
|
||||
}
|
||||
|
||||
client = EnhanceAzureChatOpenAI(
|
||||
deployment_name=self.name,
|
||||
temperature=provider_model_kwargs.get('temperature'),
|
||||
max_tokens=provider_model_kwargs.get('max_tokens'),
|
||||
model_kwargs=extra_model_kwargs,
|
||||
streaming=self.streaming,
|
||||
request_timeout=60,
|
||||
openai_api_type='azure',
|
||||
openai_api_version=AZURE_OPENAI_API_VERSION,
|
||||
openai_api_key=self.credentials.get('openai_api_key'),
|
||||
openai_api_base=self.credentials.get('openai_api_base'),
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
if isinstance(prompts, str):
|
||||
return self._client.get_num_tokens(prompts)
|
||||
else:
|
||||
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'gpt-4': {
|
||||
'prompt': decimal.Decimal('0.03'),
|
||||
'completion': decimal.Decimal('0.06'),
|
||||
},
|
||||
'gpt-4-32k': {
|
||||
'prompt': decimal.Decimal('0.06'),
|
||||
'completion': decimal.Decimal('0.12')
|
||||
},
|
||||
'gpt-35-turbo': {
|
||||
'prompt': decimal.Decimal('0.0015'),
|
||||
'completion': decimal.Decimal('0.002')
|
||||
},
|
||||
'gpt-35-turbo-16k': {
|
||||
'prompt': decimal.Decimal('0.003'),
|
||||
'completion': decimal.Decimal('0.004')
|
||||
},
|
||||
'text-davinci-003': {
|
||||
'prompt': decimal.Decimal('0.02'),
|
||||
'completion': decimal.Decimal('0.02')
|
||||
},
|
||||
}
|
||||
|
||||
base_model_name = self.credentials.get("base_model_name")
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[base_model_name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[base_model_name]['completion']
|
||||
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
if self.name == 'text-davinci-003':
|
||||
for k, v in provider_model_kwargs.items():
|
||||
if hasattr(self.client, k):
|
||||
setattr(self.client, k, v)
|
||||
else:
|
||||
extra_model_kwargs = {
|
||||
'top_p': provider_model_kwargs.get('top_p'),
|
||||
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
|
||||
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
|
||||
}
|
||||
|
||||
self.client.temperature = provider_model_kwargs.get('temperature')
|
||||
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
|
||||
self.client.model_kwargs = extra_model_kwargs
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to Azure OpenAI API.")
|
||||
return LLMBadRequestError(str(ex))
|
||||
elif isinstance(ex, openai.error.APIConnectionError):
|
||||
logging.warning("Failed to connect to Azure OpenAI API.")
|
||||
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
|
||||
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
|
||||
logging.warning("Azure OpenAI service unavailable.")
|
||||
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError('Azure ' + str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError('Azure ' + str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return True
|
269
api/core/model_providers/models/llm/base.py
Normal file
269
api/core/model_providers/models/llm/base.py
Normal file
|
@ -0,0 +1,269 @@
|
|||
from abc import abstractmethod
|
||||
from typing import List, Optional, Any, Union
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
|
||||
|
||||
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
|
||||
|
||||
class BaseLLM(BaseProviderModel):
|
||||
model_mode: ModelMode = ModelMode.COMPLETION
|
||||
name: str
|
||||
model_kwargs: ModelKwargs
|
||||
credentials: dict
|
||||
streaming: bool = False
|
||||
type: ModelType = ModelType.TEXT_GENERATION
|
||||
deduct_quota: bool = True
|
||||
|
||||
def __init__(self, model_provider: BaseModelProvider,
|
||||
name: str,
|
||||
model_kwargs: ModelKwargs,
|
||||
streaming: bool = False,
|
||||
callbacks: Callbacks = None):
|
||||
self.name = name
|
||||
self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
|
||||
self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
|
||||
max_tokens=None,
|
||||
temperature=None,
|
||||
top_p=None,
|
||||
presence_penalty=None,
|
||||
frequency_penalty=None
|
||||
)
|
||||
self.credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
self.streaming = streaming
|
||||
|
||||
if streaming:
|
||||
default_callback = DifyStreamingStdOutCallbackHandler()
|
||||
else:
|
||||
default_callback = DifyStdOutCallbackHandler()
|
||||
|
||||
if not callbacks:
|
||||
callbacks = [default_callback]
|
||||
else:
|
||||
callbacks.append(default_callback)
|
||||
|
||||
self.callbacks = callbacks
|
||||
|
||||
client = self._init_client()
|
||||
super().__init__(model_provider, client)
|
||||
|
||||
@abstractmethod
|
||||
def _init_client(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMRunResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
if self.deduct_quota:
|
||||
self.model_provider.check_quota_over_limit()
|
||||
|
||||
if not callbacks:
|
||||
callbacks = self.callbacks
|
||||
else:
|
||||
callbacks.extend(self.callbacks)
|
||||
|
||||
if 'fake_response' in kwargs and kwargs['fake_response']:
|
||||
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
|
||||
fake_llm = FakeLLM(
|
||||
response=kwargs['fake_response'],
|
||||
num_token_func=self.get_num_tokens,
|
||||
streaming=self.streaming,
|
||||
callbacks=callbacks
|
||||
)
|
||||
result = fake_llm.generate([prompts])
|
||||
else:
|
||||
try:
|
||||
result = self._run(
|
||||
messages=messages,
|
||||
stop=stop,
|
||||
callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
|
||||
**kwargs
|
||||
)
|
||||
except Exception as ex:
|
||||
raise self.handle_exceptions(ex)
|
||||
|
||||
if isinstance(result.generations[0][0], ChatGeneration):
|
||||
completion_content = result.generations[0][0].message.content
|
||||
else:
|
||||
completion_content = result.generations[0][0].text
|
||||
|
||||
if self.streaming and not self.support_streaming():
|
||||
# use FakeLLM to simulate streaming when current model not support streaming but streaming is True
|
||||
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
|
||||
fake_llm = FakeLLM(
|
||||
response=completion_content,
|
||||
num_token_func=self.get_num_tokens,
|
||||
streaming=self.streaming,
|
||||
callbacks=callbacks
|
||||
)
|
||||
fake_llm.generate([prompts])
|
||||
|
||||
if result.llm_output and result.llm_output['token_usage']:
|
||||
prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
|
||||
completion_tokens = result.llm_output['token_usage']['completion_tokens']
|
||||
total_tokens = result.llm_output['token_usage']['total_tokens']
|
||||
else:
|
||||
prompt_tokens = self.get_num_tokens(messages)
|
||||
completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
if self.deduct_quota:
|
||||
self.model_provider.deduct_quota(total_tokens)
|
||||
|
||||
return LLMRunResult(
|
||||
content=completion_content,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
"""
|
||||
get token price.
|
||||
|
||||
:param tokens:
|
||||
:param message_type:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_currency(self):
|
||||
"""
|
||||
get token currency.
|
||||
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_model_kwargs(self):
|
||||
return self.model_kwargs
|
||||
|
||||
def set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
self.model_kwargs = model_kwargs
|
||||
self._set_model_kwargs(model_kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
"""
|
||||
Handle llm run exceptions.
|
||||
|
||||
:param ex:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def add_callbacks(self, callbacks: Callbacks):
|
||||
"""
|
||||
Add callbacks to client.
|
||||
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
if not self.client.callbacks:
|
||||
self.client.callbacks = callbacks
|
||||
else:
|
||||
self.client.callbacks.extend(callbacks)
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return False
|
||||
|
||||
def _get_prompt_from_messages(self, messages: List[PromptMessage],
|
||||
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
|
||||
if len(messages) == 0:
|
||||
raise ValueError("prompt must not be empty.")
|
||||
|
||||
if not model_mode:
|
||||
model_mode = self.model_mode
|
||||
|
||||
if model_mode == ModelMode.COMPLETION:
|
||||
return messages[0].content
|
||||
else:
|
||||
chat_messages = []
|
||||
for message in messages:
|
||||
if message.type == MessageType.HUMAN:
|
||||
chat_messages.append(HumanMessage(content=message.content))
|
||||
elif message.type == MessageType.ASSISTANT:
|
||||
chat_messages.append(AIMessage(content=message.content))
|
||||
elif message.type == MessageType.SYSTEM:
|
||||
chat_messages.append(SystemMessage(content=message.content))
|
||||
|
||||
return chat_messages
|
||||
|
||||
def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
|
||||
"""
|
||||
convert model kwargs to provider model kwargs.
|
||||
|
||||
:param model_rules:
|
||||
:param model_kwargs:
|
||||
:return:
|
||||
"""
|
||||
model_kwargs_input = {}
|
||||
for key, value in model_kwargs.dict().items():
|
||||
rule = getattr(model_rules, key)
|
||||
if not rule.enabled:
|
||||
continue
|
||||
|
||||
if rule.alias:
|
||||
key = rule.alias
|
||||
|
||||
if rule.default is not None and value is None:
|
||||
value = rule.default
|
||||
|
||||
if rule.min is not None:
|
||||
value = max(value, rule.min)
|
||||
|
||||
if rule.max is not None:
|
||||
value = min(value, rule.max)
|
||||
|
||||
model_kwargs_input[key] = value
|
||||
|
||||
return model_kwargs_input
|
70
api/core/model_providers/models/llm/chatglm_model.py
Normal file
70
api/core/model_providers/models/llm/chatglm_model.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
import decimal
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.llms import ChatGLM
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
|
||||
|
||||
class ChatGLMModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.COMPLETION
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return ChatGLM(
|
||||
callbacks=self.callbacks,
|
||||
endpoint_url=self.credentials.get('api_base'),
|
||||
**provider_model_kwargs
|
||||
)
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
if hasattr(self.client, k):
|
||||
setattr(self.client, k, v)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, ValueError):
|
||||
return LLMBadRequestError(f"ChatGLM: {str(ex)}")
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return False
|
82
api/core/model_providers/models/llm/huggingface_hub_model.py
Normal file
82
api/core/model_providers/models/llm/huggingface_hub_model.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
import decimal
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain import HuggingFaceHub
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.llms import HuggingFaceEndpoint
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
|
||||
|
||||
class HuggingfaceHubModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.COMPLETION
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
|
||||
client = HuggingFaceEndpoint(
|
||||
endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
|
||||
task='text2text-generation',
|
||||
model_kwargs=provider_model_kwargs,
|
||||
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
else:
|
||||
client = HuggingFaceHub(
|
||||
repo_id=self.name,
|
||||
task=self.credentials['task_type'],
|
||||
model_kwargs=provider_model_kwargs,
|
||||
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.get_num_tokens(prompts)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
# not support calc price
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
self.client.model_kwargs = provider_model_kwargs
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return False
|
||||
|
70
api/core/model_providers/models/llm/minimax_model.py
Normal file
70
api/core/model_providers/models/llm/minimax_model.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
import decimal
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.llms import Minimax
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
|
||||
|
||||
class MinimaxModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.COMPLETION
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return Minimax(
|
||||
model=self.name,
|
||||
model_kwargs={
|
||||
'stream': False
|
||||
},
|
||||
callbacks=self.callbacks,
|
||||
**self.credentials,
|
||||
**provider_model_kwargs
|
||||
)
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
if hasattr(self.client, k):
|
||||
setattr(self.client, k, v)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, ValueError):
|
||||
return LLMBadRequestError(f"Minimax: {str(ex)}")
|
||||
else:
|
||||
return ex
|
219
api/core/model_providers/models/llm/openai_model.py
Normal file
219
api/core/model_providers/models/llm/openai_model.py
Normal file
|
@ -0,0 +1,219 @@
|
|||
import decimal
|
||||
import logging
|
||||
from typing import List, Optional, Any
|
||||
|
||||
import openai
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||
LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
|
||||
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from models.provider import ProviderType, ProviderQuotaType
|
||||
|
||||
COMPLETION_MODELS = [
|
||||
'text-davinci-003', # 4,097 tokens
|
||||
]
|
||||
|
||||
CHAT_MODELS = [
|
||||
'gpt-4', # 8,192 tokens
|
||||
'gpt-4-32k', # 32,768 tokens
|
||||
'gpt-3.5-turbo', # 4,096 tokens
|
||||
'gpt-3.5-turbo-16k', # 16,384 tokens
|
||||
]
|
||||
|
||||
MODEL_MAX_TOKENS = {
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-3.5-turbo': 4096,
|
||||
'gpt-3.5-turbo-16k': 16384,
|
||||
'text-davinci-003': 4097,
|
||||
}
|
||||
|
||||
|
||||
class OpenAIModel(BaseLLM):
|
||||
def __init__(self, model_provider: BaseModelProvider,
|
||||
name: str,
|
||||
model_kwargs: ModelKwargs,
|
||||
streaming: bool = False,
|
||||
callbacks: Callbacks = None):
|
||||
if name in COMPLETION_MODELS:
|
||||
self.model_mode = ModelMode.COMPLETION
|
||||
else:
|
||||
self.model_mode = ModelMode.CHAT
|
||||
|
||||
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
if self.name in COMPLETION_MODELS:
|
||||
client = EnhanceOpenAI(
|
||||
model_name=self.name,
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
request_timeout=60,
|
||||
**self.credentials,
|
||||
**provider_model_kwargs
|
||||
)
|
||||
else:
|
||||
# Fine-tuning is currently only available for the following base models:
|
||||
# davinci, curie, babbage, and ada.
|
||||
# This means that except for the fixed `completion` model,
|
||||
# all other fine-tuned models are `completion` models.
|
||||
extra_model_kwargs = {
|
||||
'top_p': provider_model_kwargs.get('top_p'),
|
||||
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
|
||||
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
|
||||
}
|
||||
|
||||
client = EnhanceChatOpenAI(
|
||||
model_name=self.name,
|
||||
temperature=provider_model_kwargs.get('temperature'),
|
||||
max_tokens=provider_model_kwargs.get('max_tokens'),
|
||||
model_kwargs=extra_model_kwargs,
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
request_timeout=60,
|
||||
**self.credentials
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
if self.name == 'gpt-4' \
|
||||
and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
|
||||
and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
|
||||
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
|
||||
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
if isinstance(prompts, str):
|
||||
return self._client.get_num_tokens(prompts)
|
||||
else:
|
||||
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'gpt-4': {
|
||||
'prompt': decimal.Decimal('0.03'),
|
||||
'completion': decimal.Decimal('0.06'),
|
||||
},
|
||||
'gpt-4-32k': {
|
||||
'prompt': decimal.Decimal('0.06'),
|
||||
'completion': decimal.Decimal('0.12')
|
||||
},
|
||||
'gpt-3.5-turbo': {
|
||||
'prompt': decimal.Decimal('0.0015'),
|
||||
'completion': decimal.Decimal('0.002')
|
||||
},
|
||||
'gpt-3.5-turbo-16k': {
|
||||
'prompt': decimal.Decimal('0.003'),
|
||||
'completion': decimal.Decimal('0.004')
|
||||
},
|
||||
'text-davinci-003': {
|
||||
'prompt': decimal.Decimal('0.02'),
|
||||
'completion': decimal.Decimal('0.02')
|
||||
},
|
||||
}
|
||||
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[self.name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[self.name]['completion']
|
||||
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
if self.name in COMPLETION_MODELS:
|
||||
for k, v in provider_model_kwargs.items():
|
||||
if hasattr(self.client, k):
|
||||
setattr(self.client, k, v)
|
||||
else:
|
||||
extra_model_kwargs = {
|
||||
'top_p': provider_model_kwargs.get('top_p'),
|
||||
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
|
||||
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
|
||||
}
|
||||
|
||||
self.client.temperature = provider_model_kwargs.get('temperature')
|
||||
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
|
||||
self.client.model_kwargs = extra_model_kwargs
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to OpenAI API.")
|
||||
return LLMBadRequestError(str(ex))
|
||||
elif isinstance(ex, openai.error.APIConnectionError):
|
||||
logging.warning("Failed to connect to OpenAI API.")
|
||||
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
|
||||
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
|
||||
logging.warning("OpenAI service unavailable.")
|
||||
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError(str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
raise LLMAuthorizationError(str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return True
|
||||
|
||||
# def is_model_valid_or_raise(self):
|
||||
# """
|
||||
# check is a valid model.
|
||||
#
|
||||
# :return:
|
||||
# """
|
||||
# credentials = self._model_provider.get_credentials()
|
||||
#
|
||||
# try:
|
||||
# result = openai.Model.retrieve(
|
||||
# id=self.name,
|
||||
# api_key=credentials.get('openai_api_key'),
|
||||
# request_timeout=60
|
||||
# )
|
||||
#
|
||||
# if 'id' not in result or result['id'] != self.name:
|
||||
# raise LLMNotExistsError(f"OpenAI Model {self.name} not exists.")
|
||||
# except openai.error.OpenAIError as e:
|
||||
# raise LLMNotExistsError(f"OpenAI Model {self.name} not exists, cause: {e.__class__.__name__}:{str(e)}")
|
||||
# except Exception as e:
|
||||
# logging.exception("OpenAI Model retrieve failed.")
|
||||
# raise e
|
103
api/core/model_providers/models/llm/replicate_model.py
Normal file
103
api/core/model_providers/models/llm/replicate_model.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
import decimal
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult, get_buffer_string
|
||||
from replicate.exceptions import ReplicateError, ModelError
|
||||
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.third_party.langchain.llms.replicate_llm import EnhanceReplicate
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
|
||||
|
||||
class ReplicateModel(BaseLLM):
|
||||
def __init__(self, model_provider: BaseModelProvider,
|
||||
name: str,
|
||||
model_kwargs: ModelKwargs,
|
||||
streaming: bool = False,
|
||||
callbacks: Callbacks = None):
|
||||
self.model_mode = ModelMode.CHAT if name.endswith('-chat') else ModelMode.COMPLETION
|
||||
|
||||
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
|
||||
return EnhanceReplicate(
|
||||
model=self.name + ':' + self.credentials.get('model_version'),
|
||||
input=provider_model_kwargs,
|
||||
streaming=self.streaming,
|
||||
replicate_api_token=self.credentials.get('replicate_api_token'),
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
extra_kwargs = {}
|
||||
if isinstance(prompts, list):
|
||||
system_messages = [message for message in messages if message.type == 'system']
|
||||
if system_messages:
|
||||
system_message = system_messages[0]
|
||||
extra_kwargs['system_prompt'] = system_message.content
|
||||
prompts = [message for message in messages if message.type != 'system']
|
||||
|
||||
prompts = get_buffer_string(prompts)
|
||||
|
||||
# The maximum length the generated tokens can have.
|
||||
# Corresponds to the length of the input prompt + max_new_tokens.
|
||||
if 'max_length' in self._client.input:
|
||||
self._client.input['max_length'] = min(
|
||||
self._client.input['max_length'] + self.get_num_tokens(messages),
|
||||
self.model_rules.max_tokens.max
|
||||
)
|
||||
|
||||
return self._client.generate([prompts], stop, callbacks, **extra_kwargs)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
if isinstance(prompts, list):
|
||||
prompts = get_buffer_string(prompts)
|
||||
|
||||
return self._client.get_num_tokens(prompts)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
# replicate only pay for prediction seconds
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
self.client.input = provider_model_kwargs
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, (ModelError, ReplicateError)):
|
||||
return LLMBadRequestError(f"Replicate: {str(ex)}")
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return True
|
73
api/core/model_providers/models/llm/spark_model.py
Normal file
73
api/core/model_providers/models/llm/spark_model.py
Normal file
|
@ -0,0 +1,73 @@
|
|||
import decimal
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.spark import ChatSpark
|
||||
from core.third_party.spark.spark_llm import SparkError
|
||||
|
||||
|
||||
class SparkModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.CHAT
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return ChatSpark(
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
**self.credentials,
|
||||
**provider_model_kwargs
|
||||
)
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
contents = [message.content for message in messages]
|
||||
return max(self._client.get_num_tokens("".join(contents)), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
if hasattr(self.client, k):
|
||||
setattr(self.client, k, v)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, SparkError):
|
||||
return LLMBadRequestError(f"Spark: {str(ex)}")
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return True
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user