feat: server multi models support (#799)

This commit is contained in:
takatost 2023-08-12 00:57:00 +08:00 committed by GitHub
parent d8b712b325
commit 5fa2161b05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
213 changed files with 10556 additions and 2579 deletions

View File

@ -19,7 +19,8 @@ def check_file_for_chinese_comments(file_path):
def main():
has_chinese = False
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py']
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py']
for root, _, files in os.walk("."):
for file in files:

View File

@ -102,3 +102,29 @@ NOTION_INTEGRATION_TYPE=public
NOTION_CLIENT_SECRET=you-client-secret
NOTION_CLIENT_ID=you-client-id
NOTION_INTERNAL_SECRET=you-internal-secret
# Hosted Model Credentials
HOSTED_OPENAI_ENABLED=false
HOSTED_OPENAI_API_KEY=
HOSTED_OPENAI_API_BASE=
HOSTED_OPENAI_API_ORGANIZATION=
HOSTED_OPENAI_QUOTA_LIMIT=200
HOSTED_OPENAI_PAID_ENABLED=false
HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
HOSTED_OPENAI_PAID_INCREASE_QUOTA=1
HOSTED_AZURE_OPENAI_ENABLED=false
HOSTED_AZURE_OPENAI_API_KEY=
HOSTED_AZURE_OPENAI_API_BASE=
HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
HOSTED_ANTHROPIC_ENABLED=false
HOSTED_ANTHROPIC_API_BASE=
HOSTED_ANTHROPIC_API_KEY=
HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
HOSTED_ANTHROPIC_PAID_ENABLED=false
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=

View File

@ -16,8 +16,9 @@ from flask import Flask, request, Response, session
import flask_login
from flask_cors import CORS
from core.model_providers.providers import hosted
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail
ext_database, ext_storage, ext_mail, ext_stripe
from extensions.ext_database import db
from extensions.ext_login import login_manager
@ -71,7 +72,7 @@ def create_app(test_config=None) -> Flask:
register_blueprints(app)
register_commands(app)
core.init_app(app)
hosted.init_app(app)
return app
@ -88,6 +89,7 @@ def initialize_extensions(app):
ext_login.init_app(app)
ext_mail.init_app(app)
ext_sentry.init_app(app)
ext_stripe.init_app(app)
def _create_tenant_for_account(account):
@ -246,5 +248,18 @@ def threads():
}
@app.route('/db-pool-stat')
def pool_stat():
engine = db.engine
return {
'pool_size': engine.pool.size(),
'checked_in_connections': engine.pool.checkedin(),
'checked_out_connections': engine.pool.checkedout(),
'overflow_connections': engine.pool.overflow(),
'connection_timeout': engine.pool.timeout(),
'recycle_time': db.engine.pool._recycle
}
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5001)

View File

@ -1,5 +1,5 @@
import datetime
import logging
import math
import random
import string
import time
@ -9,18 +9,18 @@ from flask import current_app
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from core.model_providers.providers.hosted import hosted_model_providers
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant
from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
from models.dataset import Dataset, DatasetQuery, Document
from models.model import Account
import secrets
import base64
from models.provider import Provider, ProviderName
from services.provider_service import ProviderService
from models.provider import Provider, ProviderType, ProviderQuotaType
@click.command('reset-password', help='Reset the account password.')
@ -251,26 +251,37 @@ def clean_unused_dataset_indexes():
@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
def sync_anthropic_hosted_providers():
if not hosted_model_providers.anthropic:
click.echo(click.style('Anthropic hosted provider is not configured.', fg='red'))
return
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
count = 0
page = 1
while True:
try:
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
providers = db.session.query(Provider).filter(
Provider.provider_name == 'anthropic',
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
except NotFound:
break
page += 1
for tenant in tenants:
for provider in providers:
try:
click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)
click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
original_quota_limit = provider.quota_limit
new_quota_limit = hosted_model_providers.anthropic.quota_limit
division = math.ceil(new_quota_limit / 1000)
provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
else original_quota_limit * division
provider.quota_used = division * provider.quota_used
db.session.commit()
count += 1
except Exception as e:
click.echo(click.style(

View File

@ -41,6 +41,7 @@ DEFAULTS = {
'SESSION_USE_SIGNER': 'True',
'DEPLOY_ENV': 'PRODUCTION',
'SQLALCHEMY_POOL_SIZE': 30,
'SQLALCHEMY_POOL_RECYCLE': 3600,
'SQLALCHEMY_ECHO': 'False',
'SENTRY_TRACES_SAMPLE_RATE': 1.0,
'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
@ -50,9 +51,16 @@ DEFAULTS = {
'PDF_PREVIEW': 'True',
'LOG_LEVEL': 'INFO',
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
'DEFAULT_LLM_PROVIDER': 'openai',
'OPENAI_HOSTED_QUOTA_LIMIT': 200,
'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_ENABLED': 'False',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
'HOSTED_ANTHROPIC_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
'TENANT_DOCUMENT_COUNT': 100,
'CLEAN_DAY_SETTING': 30
}
@ -182,7 +190,10 @@ class Config:
}
self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}"
self.SQLALCHEMY_ENGINE_OPTIONS = {'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE'))}
self.SQLALCHEMY_ENGINE_OPTIONS = {
'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
}
self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
@ -194,20 +205,35 @@ class Config:
self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
# hosted provider credentials
self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')
self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
# By default it is False
# You could disable it for compatibility with certain OpenAPI providers
self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
# For temp use only
# set default LLM provider, default is 'openai', support `azure_openai`
self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
# notion import setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')

View File

@ -18,10 +18,13 @@ from .auth import login, oauth, data_source_oauth, activate
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
# Import workspace controllers
from .workspace import workspace, members, model_providers, account, tool_providers
from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
# Import explore controllers
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
# Import universal chat controllers
from .universal_chat import chat, conversation, message, parameter, audio
# Import webhook controllers
from .webhook import stripe

View File

@ -2,16 +2,17 @@
import json
from datetime import datetime
import flask
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
from werkzeug.exceptions import Unauthorized, Forbidden
from werkzeug.exceptions import Forbidden
from constants.model_template import model_templates, demo_model_templates
from controllers.console import api
from controllers.console.app.error import AppNotFoundError
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelType
from events.app_event import app_was_created, app_was_deleted
from libs.helper import TimestampField
from extensions.ext_database import db
@ -126,9 +127,9 @@ class AppListApi(Resource):
if args['model_config'] is not None:
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=args['model_config'],
mode=args['mode']
config=args['model_config']
)
app = App(
@ -164,6 +165,21 @@ class AppListApi(Resource):
app = App(**model_config_template['app'])
app_model_config = AppModelConfig(**model_config_template['model_config'])
default_model = ModelFactory.get_default_model(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_GENERATION
)
if default_model:
model_dict = app_model_config.model_dict
model_dict['provider'] = default_model.provider_name
model_dict['name'] = default_model.model_name
app_model_config.model = json.dumps(model_dict)
else:
raise ProviderNotInitializeError(
f"No Text Generation Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
app.name = args['name']
app.mode = args['mode']
app.icon = args['icon']

View File

@ -14,7 +14,7 @@ from controllers.console.app.error import AppUnavailableError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from flask_restful import Resource
from services.audio_service import AudioService

View File

@ -17,7 +17,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.conversation_message_task import PubHandler
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from flask_restful import Resource, reqparse
@ -41,8 +41,11 @@ class CompletionMessageApi(Resource):
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
args = parser.parse_args()
streaming = args['response_mode'] != 'blocking'
account = flask_login.current_user
try:
@ -51,7 +54,7 @@ class CompletionMessageApi(Resource):
user=account,
args=args,
from_source='console',
streaming=True,
streaming=streaming,
is_model_config_override=True
)
@ -111,8 +114,11 @@ class ChatMessageApi(Resource):
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
args = parser.parse_args()
streaming = args['response_mode'] != 'blocking'
account = flask_login.current_user
try:
@ -121,7 +127,7 @@ class ChatMessageApi(Resource):
user=account,
args=args,
from_source='console',
streaming=True,
streaming=streaming,
is_model_config_override=True
)

View File

@ -7,7 +7,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.generator.llm_generator import LLMGenerator
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError

View File

@ -14,7 +14,7 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni
AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from libs.infinite_scroll_pagination import InfiniteScrollPagination

View File

@ -28,9 +28,9 @@ class ModelConfigResource(Resource):
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=request.json,
mode=app_model.mode
config=request.json
)
new_app_model_config = AppModelConfig(

View File

@ -255,7 +255,7 @@ class DataSourceNotionApi(Resource):
# validate args
DocumentService.estimate_args_validate(args)
indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
return response, 200

View File

@ -5,10 +5,13 @@ from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.model_providers.error import LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from libs.helper import TimestampField
from extensions.ext_database import db
from models.dataset import DocumentSegment, Document
@ -97,6 +100,15 @@ class DatasetListApi(Resource):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id,
@ -235,12 +247,26 @@ class DatasetIndexingEstimateApi(Resource):
raise NotFound("File not found.")
indexing_runner = IndexingRunner()
response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form'])
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
args['process_rule'], args['doc_form'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
elif args['info_list']['data_source_type'] == 'notion_import':
indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
args['process_rule'], args['doc_form'])
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
args['info_list']['notion_info_list'],
args['process_rule'], args['doc_form'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
else:
raise ValueError('Data source type not support')
return response, 200

View File

@ -18,7 +18,9 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from libs.helper import TimestampField
from extensions.ext_database import db
@ -280,6 +282,15 @@ class DatasetDocumentListApi(Resource):
# validate args
DocumentService.document_create_args_validate(args)
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
except ProviderTokenNotInitError as ex:
@ -319,6 +330,15 @@ class DatasetInitApi(Resource):
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
args = parser.parse_args()
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
# validate args
DocumentService.document_create_args_validate(args)
@ -384,7 +404,13 @@ class DocumentIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner()
response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict)
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
data_process_rule_dict)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
return response
@ -445,12 +471,24 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
indexing_runner = IndexingRunner()
response = indexing_runner.file_indexing_estimate(file_details, data_process_rule_dict)
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
data_process_rule_dict)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
elif dataset.data_source_type:
indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(info_list,
data_process_rule_dict)
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
info_list,
data_process_rule_dict)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
else:
raise ValueError('Data source type not support')
return response

View File

@ -11,7 +11,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import TimestampField
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
@ -102,6 +102,8 @@ class HitTestingApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except ValueError as e:
raise ValueError(str(e))
except Exception as e:
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))

View File

@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.explore.wraps import InstalledAppResource
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

View File

@ -15,7 +15,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.conversation_message_task import PubHandler
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from services.completion_service import CompletionService

View File

@ -15,7 +15,7 @@ from controllers.console.app.error import AppMoreLikeThisDisabledError, Provider
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.explore.wraps import InstalledAppResource
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from services.completion_service import CompletionService

View File

@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource
from core.llm.llm_builder import LLMBuilder
from models.provider import ProviderName
from models.model import InstalledApp
@ -35,13 +33,12 @@ class AppParameterApi(InstalledAppResource):
"""Retrieve app parameters."""
app_model = installed_app.app
app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(installed_app.tenant_id, 'whisper-1')
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
'speech_to_text': app_model_config.speech_to_text_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}

View File

@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

View File

@ -12,9 +12,8 @@ from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.constant import llm_constant
from core.conversation_message_task import PubHandler
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
from libs.helper import uuid_value
from services.completion_service import CompletionService
@ -27,6 +26,7 @@ class UniversalChatApi(UniversalChatResource):
parser = reqparse.RequestParser()
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('provider', type=str, required=True, location='json')
parser.add_argument('model', type=str, required=True, location='json')
parser.add_argument('tools', type=list, required=True, location='json')
args = parser.parse_args()
@ -36,11 +36,7 @@ class UniversalChatApi(UniversalChatResource):
# update app model config
args['model_config'] = app_model_config.to_dict()
args['model_config']['model']['name'] = args['model']
if not llm_constant.models[args['model']]:
raise ValueError("Model not exists.")
args['model_config']['model']['provider'] = llm_constant.models[args['model']]
args['model_config']['model']['provider'] = args['provider']
args['model_config']['agent_mode']['tools'] = args['tools']
if not args['model_config']['agent_mode']['tools']:

View File

@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from services.errors.conversation import ConversationNotExistsError

View File

@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.llm_builder import LLMBuilder
from models.provider import ProviderName
from models.model import App
@ -23,13 +21,12 @@ class UniversalChatParameterApi(UniversalChatResource):
"""Retrieve app parameters."""
app_model = universal_app
app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1')
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
'speech_to_text': app_model_config.speech_to_text_dict,
}

View File

@ -0,0 +1,53 @@
import logging
import stripe
from flask import request, current_app
from flask_restful import Resource
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import only_edition_cloud
from services.provider_checkout_service import ProviderCheckoutService
class StripeWebhookApi(Resource):
    """Receive Stripe webhook events and fulfill completed checkout orders."""

    @setup_required
    @only_edition_cloud
    def post(self):
        payload = request.data
        sig_header = request.headers.get('STRIPE_SIGNATURE')
        webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')

        try:
            # Verify the event signature so only genuine Stripe payloads are processed.
            event = stripe.Webhook.construct_event(
                payload, sig_header, webhook_secret
            )
        except ValueError:
            # Invalid payload
            return 'Invalid payload', 400
        except stripe.error.SignatureVerificationError:
            # Invalid signature
            return 'Invalid signature', 400

        # Handle the checkout.session.completed event
        if event['type'] == 'checkout.session.completed':
            logging.debug(event['data']['object']['id'])
            logging.debug(event['data']['object']['amount_subtotal'])
            logging.debug(event['data']['object']['currency'])
            logging.debug(event['data']['object']['payment_intent'])
            logging.debug(event['data']['object']['payment_status'])
            logging.debug(event['data']['object']['metadata'])

            # Fulfill the purchase...
            provider_checkout_service = ProviderCheckoutService()

            try:
                provider_checkout_service.fulfill_provider_order(event)
            except Exception:
                # Deliberately return 200 even on failure so Stripe does not
                # retry indefinitely, but record the full traceback instead of
                # a DEBUG-level message that is easy to miss in production.
                logging.exception('Failed to fulfill provider order for event %s',
                                  event['data']['object'].get('id'))

            return 'success', 200

        return 'success', 200


api.add_resource(StripeWebhookApi, '/webhook/stripe')

View File

@ -1,24 +1,18 @@
# -*- coding:utf-8 -*-
import base64
import json
import logging
from flask import current_app
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, abort
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.provider.errors import ValidateFailedError
from extensions.ext_database import db
from libs import rsa
from models.provider import Provider, ProviderType, ProviderName
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import CredentialsValidateFailedError
from services.provider_checkout_service import ProviderCheckoutService
from services.provider_service import ProviderService
class ProviderListApi(Resource):
class ModelProviderListApi(Resource):
@setup_required
@login_required
@ -26,156 +20,36 @@ class ProviderListApi(Resource):
def get(self):
tenant_id = current_user.current_tenant_id
"""
If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
rest is replaced by * and the last two bits are displayed in plaintext
If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
"""
ProviderService.init_supported_provider(current_user.current_tenant)
providers = Provider.query.filter_by(tenant_id=tenant_id).all()
provider_list = [
{
'provider_name': p.provider_name,
'provider_type': p.provider_type,
'is_valid': p.is_valid,
'last_used': p.last_used,
'is_enabled': p.is_enabled,
**({
'quota_type': p.quota_type,
'quota_limit': p.quota_limit,
'quota_used': p.quota_used
} if p.provider_type == ProviderType.SYSTEM.value else {}),
'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
ProviderName(p.provider_name), only_custom=True)
if p.provider_type == ProviderType.CUSTOM.value else None
}
for p in providers
]
provider_service = ProviderService()
provider_list = provider_service.get_provider_list(tenant_id)
return provider_list
class ProviderTokenApi(Resource):
class ModelProviderValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
if provider not in [p.value for p in ProviderName]:
abort(404)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
logging.log(logging.ERROR,
f'User {current_user.id} is not authorized to update provider token, current_role is {current_user.current_tenant.current_role}')
raise Forbidden()
def post(self, provider_name: str):
parser = reqparse.RequestParser()
parser.add_argument('token', type=ProviderService.get_token_type(
tenant=current_user.current_tenant,
provider_name=ProviderName(provider)
), required=True, nullable=False, location='json')
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
if args['token']:
try:
ProviderService.validate_provider_configs(
tenant=current_user.current_tenant,
provider_name=ProviderName(provider),
configs=args['token']
)
token_is_valid = True
except ValidateFailedError as ex:
raise ValueError(str(ex))
base64_encrypted_token = ProviderService.get_encrypted_token(
tenant=current_user.current_tenant,
provider_name=ProviderName(provider),
configs=args['token']
)
else:
base64_encrypted_token = None
token_is_valid = False
tenant = current_user.current_tenant
provider_model = db.session.query(Provider).filter(
Provider.tenant_id == tenant.id,
Provider.provider_name == provider,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
# Only allow updating token for CUSTOM provider type
if provider_model:
provider_model.encrypted_config = base64_encrypted_token
provider_model.is_valid = token_is_valid
else:
provider_model = Provider(tenant_id=tenant.id, provider_name=provider,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=base64_encrypted_token,
is_valid=token_is_valid)
db.session.add(provider_model)
if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
other_providers = db.session.query(Provider).filter(
Provider.tenant_id == tenant.id,
Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
Provider.provider_name != provider,
Provider.provider_type == ProviderType.CUSTOM.value
).all()
for other_provider in other_providers:
other_provider.is_valid = False
db.session.commit()
if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
return {'result': 'success'}, 201
class ProviderTokenValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
if provider not in [p.value for p in ProviderName]:
abort(404)
parser = reqparse.RequestParser()
parser.add_argument('token', type=ProviderService.get_token_type(
tenant=current_user.current_tenant,
provider_name=ProviderName(provider)
), required=True, nullable=False, location='json')
args = parser.parse_args()
# todo: remove this when the provider is supported
if provider in [ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
provider_service = ProviderService()
result = True
error = None
try:
ProviderService.validate_provider_configs(
tenant=current_user.current_tenant,
provider_name=ProviderName(provider),
configs=args['token']
provider_service.custom_provider_config_validate(
provider_name=provider_name,
config=args['config']
)
except ValidateFailedError as e:
except CredentialsValidateFailedError as ex:
result = False
error = str(e)
error = str(ex)
response = {'result': 'success' if result else 'error'}
@ -185,91 +59,227 @@ class ProviderTokenValidateApi(Resource):
return response
class ProviderSystemApi(Resource):
class ModelProviderUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def put(self, provider):
if provider not in [p.value for p in ProviderName]:
abort(404)
parser = reqparse.RequestParser()
parser.add_argument('is_enabled', type=bool, required=True, location='json')
args = parser.parse_args()
tenant = current_user.current_tenant_id
provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider).first()
if provider_model and provider_model.provider_type == ProviderType.SYSTEM.value:
provider_model.is_valid = args['is_enabled']
db.session.commit()
elif not provider_model:
if provider == ProviderName.OPENAI.value:
quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
elif provider == ProviderName.ANTHROPIC.value:
quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
else:
quota_limit = 0
ProviderService.create_system_provider(
tenant,
provider,
quota_limit,
args['is_enabled']
)
else:
abort(403)
return {'result': 'success'}
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
if provider not in [p.value for p in ProviderName]:
abort(404)
# The role of the current user in the ta table must be admin or owner
def post(self, provider_name: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
provider_model = db.session.query(Provider).filter(Provider.tenant_id == current_user.current_tenant_id,
Provider.provider_name == provider,
Provider.provider_type == ProviderType.SYSTEM.value).first()
parser = reqparse.RequestParser()
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
system_model = None
if provider_model:
system_model = {
'result': 'success',
'provider': {
'provider_name': provider_model.provider_name,
'provider_type': provider_model.provider_type,
'is_valid': provider_model.is_valid,
'last_used': provider_model.last_used,
'is_enabled': provider_model.is_enabled,
'quota_type': provider_model.quota_type,
'quota_limit': provider_model.quota_limit,
'quota_used': provider_model.quota_used
}
provider_service = ProviderService()
try:
provider_service.save_custom_provider_config(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
config=args['config']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {'result': 'success'}, 201
@setup_required
@login_required
@account_initialization_required
def delete(self, provider_name: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
provider_service = ProviderService()
provider_service.delete_custom_provider(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name
)
return {'result': 'success'}, 204
class ModelProviderModelValidateApi(Resource):
    """Validate a single model's credentials under a custom provider.

    Validation failures are reported in the response body (result=error),
    not as an HTTP error status.
    """

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_name: str):
        request_parser = reqparse.RequestParser()
        request_parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
        request_parser.add_argument('model_type', type=str, required=True, nullable=False,
                                    choices=['text-generation', 'embeddings', 'speech2text'], location='json')
        request_parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
        payload = request_parser.parse_args()

        try:
            ProviderService().custom_provider_model_config_validate(
                provider_name=provider_name,
                model_name=payload['model_name'],
                model_type=payload['model_type'],
                config=payload['config']
            )
        except CredentialsValidateFailedError as ex:
            return {'result': 'error', 'error': str(ex)}

        return {'result': 'success'}
class ModelProviderModelUpdateApi(Resource):
    """Create/update (POST) or remove (DELETE) a model-level config under a custom provider."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_name: str):
        # Only workspace admins/owners may change model credentials.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        body_parser = reqparse.RequestParser()
        body_parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
        body_parser.add_argument('model_type', type=str, required=True, nullable=False,
                                 choices=['text-generation', 'embeddings', 'speech2text'], location='json')
        body_parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
        body = body_parser.parse_args()

        try:
            ProviderService().add_or_save_custom_provider_model_config(
                tenant_id=current_user.current_tenant_id,
                provider_name=provider_name,
                model_name=body['model_name'],
                model_type=body['model_type'],
                config=body['config']
            )
        except CredentialsValidateFailedError as ex:
            raise ValueError(str(ex))

        return {'result': 'success'}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, provider_name: str):
        # Only workspace admins/owners may delete model credentials.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        query_parser = reqparse.RequestParser()
        query_parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
        query_parser.add_argument('model_type', type=str, required=True, nullable=False,
                                  choices=['text-generation', 'embeddings', 'speech2text'], location='args')
        query = query_parser.parse_args()

        ProviderService().delete_custom_provider_model(
            tenant_id=current_user.current_tenant_id,
            provider_name=provider_name,
            model_name=query['model_name'],
            model_type=query['model_type']
        )

        return {'result': 'success'}, 204
class PreferredProviderTypeUpdateApi(Resource):
    """Switch a provider's preferred source between 'system' and 'custom'."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_name: str):
        # Only workspace admins/owners may change the preference.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        body_parser = reqparse.RequestParser()
        body_parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
                                 choices=['system', 'custom'], location='json')
        body = body_parser.parse_args()

        ProviderService().switch_preferred_provider(
            tenant_id=current_user.current_tenant_id,
            provider_name=provider_name,
            preferred_provider_type=body['preferred_provider_type']
        )

        return {'result': 'success'}
class ModelProviderModelParameterRuleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
args = parser.parse_args()
provider_service = ProviderService()
try:
parameter_rules = provider_service.get_model_parameter_rules(
tenant_id=current_user.current_tenant_id,
model_provider_name=provider_name,
model_name=args['model_name'],
model_type='text-generation'
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"Current Text Generation Model is invalid. Please switch to the available model.")
rules = {
k: {
'enabled': v.enabled,
'min': v.min,
'max': v.max,
'default': v.default
}
else:
abort(404)
for k, v in vars(parameter_rules).items()
}
return system_model
return rules
api.add_resource(ProviderTokenApi, '/providers/<provider>/token',
endpoint='current_providers_token') # Deprecated
api.add_resource(ProviderTokenValidateApi, '/providers/<provider>/token-validate',
endpoint='current_providers_token_validate') # Deprecated
class ModelProviderPaymentCheckoutUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
provider_service = ProviderCheckoutService()
provider_checkout = provider_service.create_checkout(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
account=current_user
)
api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
endpoint='workspaces_current_providers_token') # PUT for updating provider token
api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
endpoint='workspaces_current_providers_token_validate') # POST for validating provider token
return {
'url': provider_checkout.get_checkout_url()
}
api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list
api.add_resource(ProviderSystemApi, '/workspaces/current/providers/<provider>/system',
endpoint='workspaces_current_providers_system') # GET for getting provider quota, PUT for updating provider status
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
api.add_resource(ModelProviderModelValidateApi,
'/workspaces/current/model-providers/<string:provider_name>/models/validate')
api.add_resource(ModelProviderModelUpdateApi,
'/workspaces/current/model-providers/<string:provider_name>/models')
api.add_resource(PreferredProviderTypeUpdateApi,
'/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type')
api.add_resource(ModelProviderModelParameterRuleApi,
'/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
api.add_resource(ModelProviderPaymentCheckoutUrlApi,
'/workspaces/current/model-providers/<string:provider_name>/checkout-url')

View File

@ -0,0 +1,108 @@
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from models.provider import ProviderType
from services.provider_service import ProviderService
class DefaultModelApi(Resource):
    """Read (GET) and update (POST) the workspace's default model for a model type."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        arg_parser = reqparse.RequestParser()
        arg_parser.add_argument('model_type', type=str, required=True, nullable=False,
                                choices=['text-generation', 'embeddings', 'speech2text'], location='args')
        query = arg_parser.parse_args()

        tenant_id = current_user.current_tenant_id

        default_model = ProviderService().get_default_model_of_model_type(
            tenant_id=tenant_id,
            model_type=query['model_type']
        )

        # No default configured yet for this model type.
        if not default_model:
            return None

        model_provider = ModelProviderFactory.get_preferred_model_provider(
            tenant_id,
            default_model.provider_name
        )

        # No usable provider instance: expose the bare provider name only.
        if not model_provider:
            return {
                'model_name': default_model.model_name,
                'model_type': default_model.model_type,
                'model_provider': {
                    'provider_name': default_model.provider_name
                }
            }

        provider = model_provider.provider
        payload = {
            'model_name': default_model.model_name,
            'model_type': default_model.model_type,
            'model_provider': {
                'provider_name': provider.provider_name,
                'provider_type': provider.provider_type
            }
        }

        provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)

        # System-quota providers additionally expose their quota usage.
        if provider.provider_type == ProviderType.SYSTEM.value:
            payload['model_provider']['quota_type'] = provider.quota_type
            payload['model_provider']['quota_unit'] = provider_rules['system_config']['quota_unit']
            payload['model_provider']['quota_limit'] = provider.quota_limit
            payload['model_provider']['quota_used'] = provider.quota_used

        return payload

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        body_parser = reqparse.RequestParser()
        body_parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
        body_parser.add_argument('model_type', type=str, required=True, nullable=False,
                                 choices=['text-generation', 'embeddings', 'speech2text'], location='json')
        body_parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
        body = body_parser.parse_args()

        ProviderService().update_default_model_of_model_type(
            tenant_id=current_user.current_tenant_id,
            model_type=body['model_type'],
            provider_name=body['provider_name'],
            model_name=body['model_name']
        )

        return {'result': 'success'}
class ValidModelApi(Resource):
    """List models of a given type that are currently usable in the workspace."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, model_type):
        # Raises for an unknown model type before any lookup work is done.
        ModelType.value_of(model_type)

        return ProviderService().get_valid_model_list(
            tenant_id=current_user.current_tenant_id,
            model_type=model_type
        )


api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')

View File

@ -0,0 +1,130 @@
# -*- coding:utf-8 -*-
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType
from services.provider_service import ProviderService
class ProviderListApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """
        Return the flat provider list for the current workspace.

        Credentials in `config` come back obfuscated by the service layer:
        only the first six and last two characters of a key are shown in
        plaintext, the rest is replaced by *.
        """
        tenant_id = current_user.current_tenant_id

        provider_info_list = ProviderService().get_provider_list(tenant_id)

        flat_providers = []
        for name, provider_info in provider_info_list.items():
            for p in provider_info['providers']:
                entry = {
                    'provider_name': p['provider_name'],
                    'provider_type': p['provider_type'],
                    'is_valid': p['is_valid'],
                    'last_used': p['last_used'],
                    # Mirrors is_valid; kept for backward compatibility with
                    # the old payload shape.
                    'is_enabled': p['is_valid'],
                }
                if p['provider_type'] == ProviderType.SYSTEM.value:
                    entry['quota_type'] = p['quota_type']
                    entry['quota_limit'] = p['quota_limit']
                    entry['quota_used'] = p['quota_used']
                if p['config']:
                    # OpenAI historically returned the bare key rather than
                    # the whole config object.
                    entry['token'] = p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key']
                else:
                    entry['token'] = None
                flat_providers.append(entry)

        return flat_providers
class ProviderTokenApi(Resource):
    """Legacy endpoint: save a raw provider token for the workspace."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        # Only workspace admins/owners may update provider credentials.
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        token_parser = reqparse.RequestParser()
        token_parser.add_argument('token', required=True, nullable=False, location='json')
        parsed = token_parser.parse_args()

        token_config = parsed['token']
        if provider == 'openai':
            # The new provider config format stores the key under a named field.
            token_config = {'openai_api_key': token_config}

        try:
            ProviderService().save_custom_provider_config(
                tenant_id=current_user.current_tenant_id,
                provider_name=provider,
                config=token_config
            )
        except CredentialsValidateFailedError as ex:
            raise ValueError(str(ex))

        return {'result': 'success'}, 201
class ProviderTokenValidateApi(Resource):
    """Legacy endpoint: validate a raw provider token without saving it.

    Validation failures are reported in the response body (result=error),
    not as an HTTP error status.
    """

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        token_parser = reqparse.RequestParser()
        token_parser.add_argument('token', required=True, nullable=False, location='json')
        parsed = token_parser.parse_args()

        token_config = parsed['token']
        if provider == 'openai':
            # The new provider config format stores the key under a named field.
            token_config = {'openai_api_key': token_config}

        try:
            ProviderService().custom_provider_config_validate(
                provider_name=provider,
                config=token_config
            )
        except CredentialsValidateFailedError as ex:
            return {'result': 'error', 'error': str(ex)}

        return {'result': 'success'}


api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
                 endpoint='workspaces_current_providers_token')  # PUT for updating provider token
api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
                 endpoint='workspaces_current_providers_token_validate')  # POST for validating provider token

api.add_resource(ProviderListApi, '/workspaces/current/providers')  # GET for getting providers list

View File

@ -30,7 +30,7 @@ tenant_fields = {
'created_at': TimestampField,
'role': fields.String,
'providers': fields.List(fields.Nested(provider_fields)),
'in_trail': fields.Boolean,
'in_trial': fields.Boolean,
'trial_end_reason': fields.String,
}

View File

@ -4,8 +4,6 @@ from flask_restful import fields, marshal_with
from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
from core.llm.llm_builder import LLMBuilder
from models.provider import ProviderName
from models.model import App
@ -35,13 +33,12 @@ class AppParameterApi(AppApiResource):
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
'speech_to_text': app_model_config.speech_to_text_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}

View File

@ -9,7 +9,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
ProviderNotSupportSpeechToTextError
from controllers.service_api.wraps import AppApiResource
from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from models.model import App, AppModelConfig
from services.audio_service import AudioService

View File

@ -14,7 +14,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
ProviderModelCurrentlyNotSupportError
from controllers.service_api.wraps import AppApiResource
from core.conversation_message_task import PubHandler
from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from services.completion_service import CompletionService

View File

@ -11,7 +11,7 @@ from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
DatasetNotInitedError
from controllers.service_api.wraps import DatasetApiResource
from core.llm.error import ProviderTokenNotInitError
from core.model_providers.error import ProviderTokenNotInitError
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import UploadFile

View File

@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
from controllers.web import api
from controllers.web.wraps import WebApiResource
from core.llm.llm_builder import LLMBuilder
from models.provider import ProviderName
from models.model import App
@ -34,13 +32,12 @@ class AppParameterApi(WebApiResource):
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
'speech_to_text': app_model_config.speech_to_text_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}

View File

@ -10,7 +10,7 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.web.wraps import WebApiResource
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

View File

@ -14,7 +14,7 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource
from core.conversation_message_task import PubHandler
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from services.completion_service import CompletionService

View File

@ -14,7 +14,7 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi
AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from services.completion_service import CompletionService

View File

@ -1,36 +0,0 @@
import os
from typing import Optional
import langchain
from flask import Flask
from pydantic import BaseModel
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.prompt.prompt_template import OneLineFormatter
class HostedOpenAICredential(BaseModel):
    """Credential for the platform-hosted OpenAI account (set from OPENAI_API_KEY)."""
    # API key used when requests run on the hosted OpenAI account.
    api_key: str
class HostedAnthropicCredential(BaseModel):
    """Credential for the platform-hosted Anthropic account (set from ANTHROPIC_API_KEY)."""
    # API key used when requests run on the hosted Anthropic account.
    api_key: str
class HostedLLMCredentials(BaseModel):
    """Optional hosted credentials per provider; fields stay None until init_app() fills them."""
    # Populated from app config OPENAI_API_KEY when present.
    openai: Optional[HostedOpenAICredential] = None
    # Populated from app config ANTHROPIC_API_KEY when present.
    anthropic: Optional[HostedAnthropicCredential] = None


# Module-level singleton that init_app() populates at application startup.
hosted_llm_credentials = HostedLLMCredentials()
def init_app(app: Flask):
    """Populate the hosted-LLM credential singleton from app config and
    enable verbose langchain logging when DEBUG=true in the environment."""
    debug_flag = os.environ.get("DEBUG")
    if debug_flag and debug_flag.lower() == 'true':
        langchain.verbose = True

    openai_key = app.config.get("OPENAI_API_KEY")
    if openai_key:
        hosted_llm_credentials.openai = HostedOpenAICredential(api_key=openai_key)

    anthropic_key = app.config.get("ANTHROPIC_API_KEY")
    if anthropic_key:
        hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=anthropic_key)

View File

@ -1,20 +1,17 @@
from typing import cast, List
from typing import List
from langchain import OpenAI
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import BaseMessage
from core.constant import llm_constant
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
class CalcTokenMixin:
    """Token-accounting helpers shared by agent implementations.

    NOTE(review): this span contained two interleaved versions of the mixin
    (an old ChatOpenAI/llm_constant-based one and the new model_instance-based
    one); only the new implementation is kept here.
    """

    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
        """Count the prompt tokens of *messages* using the model's own tokenizer."""
        return model_instance.get_num_tokens(to_prompt_messages(messages))

    def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
        """
        Got the rest tokens available for the model after excluding messages tokens and completion max tokens

        :param model_instance: model wrapper supplying token limits and the tokenizer
        :param messages: prompt messages to charge against the context window
        :return: number of tokens still available for additional prompt content
        """
        llm_max_tokens = model_instance.model_rules.max_tokens.max
        completion_max_tokens = model_instance.model_kwargs.max_tokens
        used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
        rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens

        return rest_tokens

View File

@ -4,9 +4,11 @@ from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@ -14,6 +16,12 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
An Multi Dataset Retrieve Agent driven by Router.
"""
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""

View File

@ -6,7 +6,8 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
@ -84,7 +85,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))

View File

@ -3,20 +3,28 @@ from typing import cast, List
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from pydantic import BaseModel
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.model_providers.models.llm.base import BaseLLM
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
model_instance: BaseLLM
def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages

View File

@ -6,7 +6,8 @@ from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFuncti
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
@ -84,7 +85,7 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))

View File

@ -0,0 +1,162 @@
import re
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
model_instance: BaseLLM
dataset_tools: Sequence[BaseTool]
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
return should use agent
Using the ReACT mode to determine whether an agent is needed is costly,
so it's better to just use an Agent for reasoning, which is cheaper.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.dataset_tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.dataset_tools) == 1:
tool = next(iter(self.dataset_tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
try:
return self.output_parser.parse(full_output)
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
formatted_tools = "\n".join(tool_strings)
unique_tool_names = set(tool.name for tool in tools)
tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
output_parser=output_parser,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
dataset_tools=tools,
**kwargs,
)

View File

@ -14,7 +14,7 @@ from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.model_providers.models.llm.base import BaseLLM
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@ -53,6 +53,12 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
@ -89,7 +95,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
if prompts:
messages = prompts[0].to_messages()
rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

View File

@ -3,7 +3,6 @@ import logging
from typing import Union, Optional
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool
@ -13,14 +12,17 @@ from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class PlanningStrategy(str, enum.Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
MULTI_FUNCTION_CALL = 'multi_function_call'
@ -28,10 +30,9 @@ class PlanningStrategy(str, enum.Enum):
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
llm: BaseLanguageModel
model_instance: BaseLLM
tools: list[BaseTool]
summary_llm: BaseLanguageModel
dataset_llm: BaseLanguageModel
summary_model_instance: BaseLLM
memory: Optional[BaseChatMemory] = None
callbacks: Callbacks = None
max_iterations: int = 6
@ -60,36 +61,49 @@ class AgentExecutor:
def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
llm=self.configuration.llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_llm,
summary_llm=self.configuration.summary_model_instance.client,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
llm=self.configuration.llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_llm,
summary_llm=self.configuration.summary_model_instance.client,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
llm=self.configuration.llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_llm,
summary_llm=self.configuration.summary_model_instance.client,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
llm=self.configuration.dataset_llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
verbose=True
)
else:
raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")

View File

@ -10,15 +10,16 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.models.llm.base import BaseLLM
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self.model_name = model_name
self.model_instant = model_instant
self.conversation_message_task = conversation_message_task
self._agent_loops = []
self._current_loop = None
@ -152,7 +153,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_name, self._current_loop
self._message_agent_thought, self.model_instant, self._current_loop
)
self._agent_loops.append(self._current_loop)
@ -183,7 +184,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
)
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_name, self._current_loop
self._message_agent_thought, self.model_instant, self._current_loop
)
self._agent_loops.append(self._current_loop)

View File

@ -3,18 +3,20 @@ import time
from typing import Any, Dict, List, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel
from langchain.schema import LLMResult, BaseMessage
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
from core.model_providers.models.llm.base import BaseLLM
class LLMCallbackHandler(BaseCallbackHandler):
raise_error: bool = True
def __init__(self, llm: BaseLanguageModel,
def __init__(self, model_instance: BaseLLM,
conversation_message_task: ConversationMessageTask):
self.llm = llm
self.model_instance = model_instance
self.llm_message = LLMMessage()
self.start_at = None
self.conversation_message_task = conversation_message_task
@ -46,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
})
self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@ -58,7 +60,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
"text": prompts[0]
}]
self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
end_at = time.perf_counter()
@ -68,7 +70,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.conversation_message_task.append_message_text(response.generations[0][0].text)
self.llm_message.completion = response.generations[0][0].text
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
self.conversation_message_task.save_message(self.llm_message)
@ -89,7 +91,9 @@ class LLMCallbackHandler(BaseCallbackHandler):
if self.conversation_message_task.streaming:
end_at = time.perf_counter()
self.llm_message.latency = end_at - self.start_at
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
else:
logging.error(error)

View File

@ -5,9 +5,7 @@ from typing import Any, Dict, Union
from langchain.callbacks.base import BaseCallbackHandler
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.entity.chain_result import ChainResult
from core.constant import llm_constant
from core.conversation_message_task import ConversationMessageTask

View File

@ -2,27 +2,19 @@ import logging
import re
from typing import Optional, List, Union, Tuple
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, HumanMessage
from langchain.schema import BaseMessage
from requests.exceptions import ChunkedEncodingError
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.fake import FakeLLM
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
@ -51,12 +43,10 @@ class Completion:
inputs = conversation.inputs
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
mode=app.mode,
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=app.tenant_id,
app_model_config=app_model_config,
query=query,
inputs=inputs
model_config=app_model_config.model_dict,
streaming=streaming
)
conversation_message_task = ConversationMessageTask(
@ -68,10 +58,17 @@ class Completion:
is_override=is_override,
inputs=inputs,
query=query,
streaming=streaming
streaming=streaming,
model_instance=final_model_instance
)
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
mode=app.mode,
model_instance=final_model_instance,
app_model_config=app_model_config,
query=query,
inputs=inputs
)
# init orchestrator rule parser
orchestrator_rule_parser = OrchestratorRuleParser(
@ -80,6 +77,7 @@ class Completion:
)
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
if sensitive_word_avoidance_chain:
query = sensitive_word_avoidance_chain.run(query)
@ -102,15 +100,14 @@ class Completion:
# run the final llm
try:
cls.run_final_llm(
tenant_id=app.tenant_id,
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
agent_execute_result=agent_execute_result,
conversation_message_task=conversation_message_task,
memory=memory,
streaming=streaming
memory=memory
)
except ConversationTaskStoppedException:
return
@ -121,31 +118,20 @@ class Completion:
return
@classmethod
def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
agent_execute_result: Optional[AgentExecuteResult],
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
# When no extra pre prompt is specified,
# the output of the agent can be used directly as the main output content without calling LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy != PlanningStrategy.ROUTER:
final_llm = FakeLLM(response=agent_execute_result.output,
origin_llm=agent_execute_result.configuration.llm,
streaming=streaming)
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
response = final_llm.generate([[HumanMessage(content=query)]])
return response
final_llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id,
model=app_model_config.model_dict,
streaming=streaming
)
fake_response = agent_execute_result.output
# get llm prompt
prompt, stop_words = cls.get_main_llm_prompt(
prompt_messages, stop_words = cls.get_main_llm_prompt(
mode=mode,
llm=final_llm,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
@ -154,25 +140,26 @@ class Completion:
memory=memory
)
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens(
final_llm=final_llm,
model=app_model_config.model_dict,
prompt=prompt,
mode=mode
model_instance=model_instance,
prompt_messages=prompt_messages,
)
response = final_llm.generate([prompt], stop_words)
response = model_instance.run(
messages=prompt_messages,
stop=stop_words,
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
return response
@classmethod
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
def get_main_llm_prompt(cls, mode: str, model: dict,
pre_prompt: str, query: str, inputs: dict,
agent_execute_result: Optional[AgentExecuteResult],
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
Tuple[List[PromptMessage], Optional[List[str]]]:
if mode == 'completion':
prompt_template = JinjaPromptTemplate.from_template(
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
@ -200,11 +187,7 @@ And answer according to the language of the user's question.
**prompt_inputs
)
if isinstance(llm, BaseChatModel):
# use chat llm as completion model
return [HumanMessage(content=prompt_content)], None
else:
return prompt_content, None
return [PromptMessage(content=prompt_content)], None
else:
messages: List[BaseMessage] = []
@ -249,12 +232,14 @@ And answer according to the language of the user's question.
inputs=human_inputs
)
curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
model_name = model['name']
max_tokens = model.get("completion_params").get('max_tokens')
rest_tokens = llm_constant.max_context_token_length[model_name] \
- max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
if memory.model_instance.model_rules.max_tokens.max:
curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
max_tokens = model.get("completion_params").get('max_tokens')
rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
else:
rest_tokens = 2000
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
human_message_prompt += "\n\n" if human_message_prompt else ""
human_message_prompt += "Here is the chat histories between human and assistant, " \
@ -274,17 +259,7 @@ And answer according to the language of the user's question.
for message in messages:
message.content = re.sub(r'<\|.*?\|>', '', message.content)
return messages, ['\nHuman:', '</histories>']
@classmethod
def get_llm_callbacks(cls, llm: BaseLanguageModel,
streaming: bool,
conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
if streaming:
return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
else:
return [llm_callback_handler, DifyStdOutCallbackHandler()]
return to_prompt_messages(messages), ['\nHuman:', '</histories>']
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
@ -300,15 +275,15 @@ And answer according to the language of the user's question.
conversation: Conversation,
**kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
# only for calc token in memory
memory_llm = LLMBuilder.to_llm_from_model(
memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=tenant_id,
model=app_model_config.model_dict
model_config=app_model_config.model_dict
)
# use llm config from conversation
memory = ReadOnlyConversationTokenDBBufferSharedMemory(
conversation=conversation,
llm=memory_llm,
model_instance=memory_model_instance,
max_token_limit=kwargs.get("max_token_limit", 2048),
memory_key=kwargs.get("memory_key", "chat_history"),
return_messages=kwargs.get("return_messages", True),
@ -320,21 +295,20 @@ And answer according to the language of the user's question.
return memory
@classmethod
def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
query: str, inputs: dict) -> int:
llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id,
model=app_model_config.model_dict
)
model_limited_tokens = model_instance.model_rules.max_tokens.max
max_tokens = model_instance.get_model_kwargs().max_tokens
model_name = app_model_config.model_dict.get("name")
model_limited_tokens = llm_constant.max_context_token_length[model_name]
max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
if model_limited_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
# get prompt without memory and context
prompt, _ = cls.get_main_llm_prompt(
prompt_messages, _ = cls.get_main_llm_prompt(
mode=mode,
llm=llm,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
@ -343,9 +317,7 @@ And answer according to the language of the user's question.
memory=None
)
prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
else llm.get_num_tokens_from_messages(prompt)
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
@ -354,36 +326,40 @@ And answer according to the language of the user's question.
return rest_tokens
@classmethod
def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
prompt: Union[str, List[BaseMessage]], mode: str):
def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_name = model.get("name")
model_limited_tokens = llm_constant.max_context_token_length[model_name]
max_tokens = model.get("completion_params").get('max_tokens')
model_limited_tokens = model_instance.model_rules.max_tokens.max
max_tokens = model_instance.get_model_kwargs().max_tokens
if mode == 'completion' and isinstance(final_llm, BaseLLM):
prompt_tokens = final_llm.get_num_tokens(prompt)
else:
prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
if model_limited_tokens is None:
return
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_limited_tokens:
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
final_llm.max_tokens = max_tokens
# update model instance max tokens
model_kwargs = model_instance.get_model_kwargs()
model_kwargs.max_tokens = max_tokens
model_instance.set_model_kwargs(model_kwargs)
@classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool):
llm = LLMBuilder.to_llm_from_model(
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=app.tenant_id,
model=app_model_config.model_dict,
model_config=app_model_config.model_dict,
streaming=streaming
)
# get llm prompt
original_prompt, _ = cls.get_main_llm_prompt(
old_prompt_messages, _ = cls.get_main_llm_prompt(
mode="completion",
llm=llm,
model=app_model_config.model_dict,
pre_prompt=pre_prompt,
query=message.query,
@ -395,10 +371,9 @@ And answer according to the language of the user's question.
original_completion = message.answer.strip()
prompt = MORE_LIKE_THIS_GENERATE_PROMPT
prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)
prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
prompt_messages = [PromptMessage(content=prompt)]
conversation_message_task = ConversationMessageTask(
task_id=task_id,
@ -408,16 +383,16 @@ And answer according to the language of the user's question.
inputs=message.inputs,
query=message.query,
is_override=True if message.override_model_configs else False,
streaming=streaming
streaming=streaming,
model_instance=final_model_instance
)
llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens(
final_llm=llm,
model=app_model_config.model_dict,
prompt=prompt,
mode='completion'
model_instance=final_model_instance,
prompt_messages=prompt_messages
)
llm.generate([prompt])
final_model_instance.run(
messages=prompt_messages,
callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
)

View File

@ -1,109 +0,0 @@
# Static model registry: provider ownership, context limits, per-mode
# availability and USD token prices for the hosted LLMs.
# NOTE: import from the public `decimal` package, not the private
# `_decimal` C-accelerator module.
from decimal import Decimal

# model name -> provider that serves it
models = {
    'claude-instant-1': 'anthropic',  # 100,000 tokens
    'claude-2': 'anthropic',  # 100,000 tokens
    'gpt-4': 'openai',  # 8,192 tokens
    'gpt-4-32k': 'openai',  # 32,768 tokens
    'gpt-3.5-turbo': 'openai',  # 4,096 tokens
    'gpt-3.5-turbo-16k': 'openai',  # 16384 tokens
    'text-davinci-003': 'openai',  # 4,097 tokens
    'text-davinci-002': 'openai',  # 4,097 tokens
    'text-curie-001': 'openai',  # 2,049 tokens
    'text-babbage-001': 'openai',  # 2,049 tokens
    'text-ada-001': 'openai',  # 2,049 tokens
    'text-embedding-ada-002': 'openai',  # 8191 tokens, 1536 dimensions
    'whisper-1': 'openai'
}

# model name -> maximum context window size in tokens
max_context_token_length = {
    'claude-instant-1': 100000,
    'claude-2': 100000,
    'gpt-4': 8192,
    'gpt-4-32k': 32768,
    'gpt-3.5-turbo': 4096,
    'gpt-3.5-turbo-16k': 16384,
    'text-davinci-003': 4097,
    'text-davinci-002': 4097,
    'text-curie-001': 2049,
    'text-babbage-001': 2049,
    'text-ada-001': 2049,
    'text-embedding-ada-002': 8191,
}

# models usable in each application mode
models_by_mode = {
    'chat': [
        'claude-instant-1',  # 100,000 tokens
        'claude-2',  # 100,000 tokens
        'gpt-4',  # 8,192 tokens
        'gpt-4-32k',  # 32,768 tokens
        'gpt-3.5-turbo',  # 4,096 tokens
        'gpt-3.5-turbo-16k',  # 16,384 tokens
    ],
    'completion': [
        'claude-instant-1',  # 100,000 tokens
        'claude-2',  # 100,000 tokens
        'gpt-4',  # 8,192 tokens
        'gpt-4-32k',  # 32,768 tokens
        'gpt-3.5-turbo',  # 4,096 tokens
        'gpt-3.5-turbo-16k',  # 16,384 tokens
        'text-davinci-003',  # 4,097 tokens
        # BUGFIX: trailing commas were missing after 'text-davinci-002' and
        # before 'text-ada-001', so adjacent string literals silently
        # concatenated ('text-davinci-002text-curie-001'), dropping models
        # from this list.
        'text-davinci-002',  # 4,097 tokens
        'text-curie-001',  # 2,049 tokens
        'text-babbage-001',  # 2,049 tokens
        'text-ada-001',  # 2,049 tokens
    ],
    'embedding': [
        'text-embedding-ada-002'  # 8191 tokens, 1536 dimensions
    ]
}

model_currency = 'USD'

# model name -> price per 1K tokens ('prompt'/'completion', or 'usage' for
# embeddings).  NOTE(review): 'gpt-3.5-turbo-16k' has a price but
# 'text-davinci-002' does not — lookups for it will raise KeyError; confirm
# whether that omission is intentional.
model_prices = {
    'claude-instant-1': {
        'prompt': Decimal('0.00163'),
        'completion': Decimal('0.00551'),
    },
    'claude-2': {
        'prompt': Decimal('0.01102'),
        'completion': Decimal('0.03268'),
    },
    'gpt-4': {
        'prompt': Decimal('0.03'),
        'completion': Decimal('0.06'),
    },
    'gpt-4-32k': {
        'prompt': Decimal('0.06'),
        'completion': Decimal('0.12')
    },
    'gpt-3.5-turbo': {
        'prompt': Decimal('0.0015'),
        'completion': Decimal('0.002')
    },
    'gpt-3.5-turbo-16k': {
        'prompt': Decimal('0.003'),
        'completion': Decimal('0.004')
    },
    'text-davinci-003': {
        'prompt': Decimal('0.02'),
        'completion': Decimal('0.02')
    },
    'text-curie-001': {
        'prompt': Decimal('0.002'),
        'completion': Decimal('0.002')
    },
    'text-babbage-001': {
        'prompt': Decimal('0.0005'),
        'completion': Decimal('0.0005')
    },
    'text-ada-001': {
        'prompt': Decimal('0.0004'),
        'completion': Decimal('0.0004')
    },
    'text-embedding-ada-002': {
        'usage': Decimal('0.0001'),
    }
}

# default model used by the agent runtime
agent_model_name = 'text-davinci-003'

View File

@ -6,9 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult
from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.provider.llm_provider_service import LLMProviderService
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from events.message_event import message_was_created
@ -16,12 +16,11 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
from models.provider import ProviderType, Provider
class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, streaming: bool,
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
conversation: Optional[Conversation] = None, is_override: bool = False):
self.task_id = task_id
@ -38,9 +37,12 @@ class ConversationMessageTask:
self.conversation = conversation
self.is_new_conversation = False
self.model_instance = model_instance
self.message = None
self.model_dict = self.app_model_config.model_dict
self.provider_name = self.model_dict.get('provider')
self.model_name = self.model_dict.get('name')
self.mode = app.mode
@ -56,9 +58,6 @@ class ConversationMessageTask:
)
def init(self):
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
self.model_dict['provider'] = provider_name
override_model_configs = None
if self.is_override:
override_model_configs = {
@ -89,15 +88,19 @@ class ConversationMessageTask:
if self.app_model_config.pre_prompt:
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
system_instruction = system_message.content
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_provider_name=self.provider_name,
model_name=self.model_name
)
system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))
if not self.conversation:
self.is_new_conversation = True
self.conversation = Conversation(
app_id=self.app_model_config.app_id,
app_model_config_id=self.app_model_config.id,
model_provider=self.model_dict.get('provider'),
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=self.mode,
@ -117,7 +120,7 @@ class ConversationMessageTask:
self.message = Message(
app_id=self.app_model_config.app_id,
model_provider=self.model_dict.get('provider'),
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=self.conversation.id,
@ -131,7 +134,7 @@ class ConversationMessageTask:
answer_unit_price=0,
provider_response_latency=0,
total_price=0,
currency=llm_constant.model_currency,
currency=self.model_instance.get_currency(),
from_source=('console' if isinstance(self.user, Account) else 'api'),
from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
from_account_id=(self.user.id if isinstance(self.user, Account) else None),
@ -145,12 +148,10 @@ class ConversationMessageTask:
self._pub_handler.pub_text(text)
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
model_name = self.app_model_config.model_dict.get('name')
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
message_unit_price = llm_constant.model_prices[model_name]['prompt']
answer_unit_price = llm_constant.model_prices[model_name]['completion']
message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
@ -163,8 +164,6 @@ class ConversationMessageTask:
self.message.provider_response_latency = llm_message.latency
self.message.total_price = total_price
self.update_provider_quota()
db.session.commit()
message_was_created.send(
@ -176,20 +175,6 @@ class ConversationMessageTask:
if not by_stopped:
self.end()
def update_provider_quota(self):
llm_provider_service = LLMProviderService(
tenant_id=self.app.tenant_id,
provider_name=self.message.model_provider,
)
provider = llm_provider_service.get_provider_db_record()
if provider and provider.provider_type == ProviderType.SYSTEM.value:
db.session.query(Provider).filter(
Provider.tenant_id == self.app.tenant_id,
Provider.provider_name == provider.provider_name,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + 1})
def init_chain(self, chain_result: ChainResult):
message_chain = MessageChain(
message_id=self.message.id,
@ -229,10 +214,10 @@ class ConversationMessageTask:
return message_agent_thought
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
@ -253,7 +238,7 @@ class ConversationMessageTask:
message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price
message_agent_thought.currency = llm_constant.model_currency
message_agent_thought.currency = agent_model_instant.get_currency()
db.session.flush()
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Sequence
from langchain.schema import Document
from sqlalchemy import func
from core.llm.token_calculator import TokenCalculator
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
@ -13,12 +13,10 @@ class DatesetDocumentStore:
self,
dataset: Dataset,
user_id: str,
embedding_model_name: str,
document_id: Optional[str] = None,
):
self._dataset = dataset
self._user_id = user_id
self._embedding_model_name = embedding_model_name
self._document_id = document_id
@classmethod
@ -39,10 +37,6 @@ class DatesetDocumentStore:
def user_id(self) -> Any:
return self._user_id
@property
def embedding_model_name(self) -> Any:
return self._embedding_model_name
@property
def docs(self) -> Dict[str, Document]:
document_segments = db.session.query(DocumentSegment).filter(
@ -74,6 +68,10 @@ class DatesetDocumentStore:
if max_position is None:
max_position = 0
embedding_model = ModelFactory.get_embedding_model(
tenant_id=self._dataset.tenant_id
)
for doc in docs:
if not isinstance(doc, Document):
raise ValueError("doc must be a Document")
@ -88,7 +86,7 @@ class DatesetDocumentStore:
)
# calc embedding use tokens
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)
tokens = embedding_model.get_num_tokens(doc.page_content)
if not segment_document:
max_position += 1

View File

@ -4,14 +4,14 @@ from typing import List
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from core.model_providers.models.embedding.base import BaseEmbedding
from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding
class CacheEmbedding(Embeddings):
def __init__(self, embeddings: Embeddings):
def __init__(self, embeddings: BaseEmbedding):
self._embeddings = embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
@ -21,48 +21,54 @@ class CacheEmbedding(Embeddings):
embedding_queue_texts = []
for text in texts:
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
if embedding:
text_embeddings.append(embedding.get_embedding())
else:
embedding_queue_texts.append(text)
embedding_results = self._embeddings.embed_documents(embedding_queue_texts)
i = 0
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
if embedding_queue_texts:
try:
embedding = Embedding(hash=hash)
embedding.set_embedding(embedding_results[i])
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
finally:
i += 1
embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
except Exception as ex:
raise self._embeddings.handle_exceptions(ex)
text_embeddings.extend(embedding_results)
i = 0
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
try:
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
embedding.set_embedding(embedding_results[i])
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
finally:
i += 1
text_embeddings.extend(embedding_results)
return text_embeddings
@handle_openai_exceptions
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
if embedding:
return embedding.get_embedding()
embedding_results = self._embeddings.embed_query(text)
try:
embedding_results = self._embeddings.client.embed_query(text)
except Exception as ex:
raise self._embeddings.handle_exceptions(ex)
try:
embedding = Embedding(hash=hash)
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
embedding.set_embedding(embedding_results)
db.session.add(embedding)
db.session.commit()
@ -72,3 +78,5 @@ class CacheEmbedding(Embeddings):
logging.exception('Failed to add embedding to db')
return embedding_results

View File

@ -1,13 +1,10 @@
import logging
from langchain import PromptTemplate
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage
from langchain.schema import OutputParserException
from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.llm.token_calculator import TokenCalculator
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
@ -15,9 +12,6 @@ from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTempla
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
GENERATOR_QA_PROMPT
# gpt-3.5-turbo works not well
generate_base_model = 'text-davinci-003'
class LLMGenerator:
@classmethod
@ -28,29 +22,35 @@ class LLMGenerator:
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
prompt = prompt.format(query=query)
llm: StreamableOpenAI = LLMBuilder.to_llm(
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=50,
timeout=600
model_kwargs=ModelKwargs(
max_tokens=50
)
)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
def generate_conversation_summary(cls, tenant_id: str, messages):
max_tokens = 200
model = 'gpt-3.5-turbo'
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=max_tokens
)
)
prompt = CONVERSATION_SUMMARY_PROMPT
prompt_with_empty_context = prompt.format(context='')
prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context)
rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
max_context_token_length = model_instance.model_rules.max_tokens.max
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
context = ''
for message in messages:
@ -68,25 +68,16 @@ class LLMGenerator:
answer = message.answer
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
context += message_qa_text
if not context:
return '[message too long, no summary]'
prompt = prompt.format(context=context)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=model,
max_tokens=max_tokens
)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
@ -94,16 +85,13 @@ class LLMGenerator:
prompt = INTRODUCTION_GENERATE_PROMPT
prompt = prompt.format(prompt=pre_prompt)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=generate_base_model,
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
@ -119,23 +107,19 @@ class LLMGenerator:
_input = prompt.format_prompt(histories=histories)
llm: StreamableOpenAI = LLMBuilder.to_llm(
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
temperature=0,
max_tokens=256
model_kwargs=ModelKwargs(
max_tokens=256,
temperature=0
)
)
if isinstance(llm, BaseChatModel):
query = [HumanMessage(content=_input.to_string())]
else:
query = _input.to_string()
prompts = [PromptMessage(content=_input.to_string())]
try:
output = llm(query)
if isinstance(output, BaseMessage):
output = output.content
questions = output_parser.parse(output)
output = model_instance.run(prompts)
questions = output_parser.parse(output.content)
except Exception:
logging.exception("Error generating suggested questions after answer")
questions = []
@ -160,21 +144,19 @@ class LLMGenerator:
_input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
llm: StreamableOpenAI = LLMBuilder.to_llm(
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_name=generate_base_model,
temperature=0,
max_tokens=512
model_kwargs=ModelKwargs(
max_tokens=512,
temperature=0
)
)
if isinstance(llm, BaseChatModel):
query = [HumanMessage(content=_input.to_string())]
else:
query = _input.to_string()
prompts = [PromptMessage(content=_input.to_string())]
try:
output = llm(query)
rule_config = output_parser.parse(output)
output = model_instance.run(prompts)
rule_config = output_parser.parse(output.content)
except OutputParserException:
raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
except Exception:
@ -188,25 +170,21 @@ class LLMGenerator:
return rule_config
@classmethod
async def generate_qa_document(cls, llm: StreamableOpenAI, query):
def generate_qa_document(cls, tenant_id: str, query):
prompt = GENERATOR_QA_PROMPT
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=2000
)
)
if isinstance(llm, BaseChatModel):
prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
prompts = [
PromptMessage(content=prompt, type=MessageType.SYSTEM),
PromptMessage(content=query)
]
response = llm.generate([prompt])
answer = response.generations[0][0].text
return answer.strip()
@classmethod
def generate_qa_document_sync(cls, llm: StreamableOpenAI, query):
prompt = GENERATOR_QA_PROMPT
if isinstance(llm, BaseChatModel):
prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
response = model_instance.run(prompts)
answer = response.content
return answer.strip()

View File

@ -0,0 +1,20 @@
import base64
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
def obfuscated_token(token: str):
    """Return a display-safe masked copy of *token*.

    Keeps the first 6 and last 2 characters and replaces the middle with
    ``*``.  Tokens of 8 characters or fewer are fully masked: with the
    previous slicing logic the mask width ``len(token) - 8`` was <= 0, so
    short tokens were echoed back essentially unmasked (and characters were
    duplicated for very short ones).
    """
    if len(token) <= 8:
        # Too short to partially reveal without leaking most of the secret.
        return '*' * len(token)
    return token[:6] + '*' * (len(token) - 8) + token[-2:]
def encrypt_token(tenant_id: str, token: str):
    """Encrypt *token* with the tenant's RSA public key.

    Looks up the tenant row by id, encrypts the token with its
    ``encrypt_public_key`` and returns the ciphertext base64-encoded as a
    string.  NOTE(review): assumes the tenant exists — a missing row would
    raise AttributeError on ``tenant.encrypt_public_key``; confirm callers
    guarantee this.
    """
    tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
    ciphertext = rsa.encrypt(token, tenant.encrypt_public_key)
    return base64.b64encode(ciphertext).decode()
def decrypt_token(tenant_id: str, token: str):
    """Decode the base64 *token* and decrypt it with the tenant's RSA key."""
    raw = base64.b64decode(token)
    return rsa.decrypt(raw, tenant_id)

View File

@ -1,10 +1,9 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from core.model_providers.model_factory import ModelFactory
from models.dataset import Dataset
@ -15,16 +14,11 @@ class IndexBuilder:
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
return None
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
max_retries=1,
**model_credentials
))
embeddings = CacheEmbedding(embedding_model)
return VectorIndex(
dataset=dataset,

View File

@ -1,4 +1,3 @@
import concurrent
import datetime
import json
import logging
@ -6,7 +5,6 @@ import re
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, cast
from flask_login import current_user
@ -18,11 +16,10 @@ from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.llm.error import ProviderTokenNotInitError
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import MessageType
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -35,9 +32,8 @@ from models.source import DataSourceBinding
class IndexingRunner:
def __init__(self, embedding_model_name: str = "text-embedding-ada-002"):
def __init__(self):
self.storage = storage
self.embedding_model_name = embedding_model_name
def run(self, dataset_documents: List[DatasetDocument]):
"""Run the indexing process."""
@ -227,11 +223,15 @@ class IndexingRunner:
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict,
def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
doc_form: str = None) -> dict:
"""
Estimate the indexing for the document.
"""
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
tokens = 0
preview_texts = []
total_segments = 0
@ -253,44 +253,49 @@ class IndexingRunner:
splitter=splitter,
processing_rule=processing_rule
)
total_segments += len(documents)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
self.filter_string(document.page_content))
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
text_generation_model = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if doc_form and doc_form == 'qa_model':
if len(preview_texts) > 0:
# qa model document
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=current_user.current_tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=2000
)
response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
"currency": TokenCalculator.get_currency(self.embedding_model_name),
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
}
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
"currency": TokenCalculator.get_currency(self.embedding_model_name),
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
"currency": embedding_model.get_currency(),
"preview": preview_texts
}
def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
"""
Estimate the indexing for the document.
"""
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
# load data from notion
tokens = 0
preview_texts = []
@ -336,31 +341,31 @@ class IndexingRunner:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
tokens += embedding_model.get_num_tokens(document.page_content)
text_generation_model = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if doc_form and doc_form == 'qa_model':
if len(preview_texts) > 0:
# qa model document
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=current_user.current_tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=2000
)
response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
"currency": TokenCalculator.get_currency(self.embedding_model_name),
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
}
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
"currency": TokenCalculator.get_currency(self.embedding_model_name),
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
"currency": embedding_model.get_currency(),
"preview": preview_texts
}
@ -459,7 +464,6 @@ class IndexingRunner:
doc_store = DatesetDocumentStore(
dataset=dataset,
user_id=dataset_document.created_by,
embedding_model_name=self.embedding_model_name,
document_id=dataset_document.id
)
@ -513,17 +517,12 @@ class IndexingRunner:
all_documents.extend(split_documents)
# processing qa document
if document_form == 'qa_model':
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=2000
)
for i in range(0, len(all_documents), 10):
threads = []
sub_documents = all_documents[i:i + 10]
for doc in sub_documents:
document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents})
'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents})
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
@ -531,13 +530,13 @@ class IndexingRunner:
return all_qa_documents
return all_documents
def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents):
def format_qa_document(self, tenant_id: str, document_node, all_qa_documents):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():
return
try:
# qa model document
response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
document_qa_list = self.format_split_text(response)
qa_documents = []
for result in document_qa_list:
@ -638,6 +637,10 @@ class IndexingRunner:
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
)
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
tokens = 0
@ -648,7 +651,7 @@ class IndexingRunner:
chunk_documents = documents[i:i + chunk_size]
tokens += sum(
TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
embedding_model.get_num_tokens(document.page_content)
for document in chunk_documents
)

View File

@ -1,148 +0,0 @@
from typing import Union, Optional, List
from langchain.callbacks.base import BaseCallbackHandler
from core.constant import llm_constant
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from models.provider import ProviderType, ProviderName
class LLMBuilder:
    """
    Factory for streamable LLM client instances.

    This class handles the following logic:
    1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config.
    2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below:
        OPENAI_API_TYPE=azure
        OPENAI_API_VERSION=2022-12-01
        OPENAI_API_BASE=https://your-resource-name.openai.azure.com
        OPENAI_API_KEY=<your Azure OpenAI API key>
    3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config.
    4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config.
    5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config.
    6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot be created through the admin interface.
    7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter.
    8. For providers with the provider_type 'System', the quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting.
    """

    @classmethod
    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
        """
        Build a streamable LLM client for `model_name`, using the tenant's
        resolved provider credentials.

        :param tenant_id: tenant whose provider credentials are used
        :param model_name: model identifier (must appear in llm_constant tables)
        :param kwargs: completion params (temperature, max_tokens, top_p,
            frequency_penalty, presence_penalty, callbacks, streaming)
        :raises ValueError: if no client class supports this model/provider combo
        """
        provider = cls.get_default_provider(tenant_id, model_name)

        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)

        llm_cls = None
        mode = cls.get_mode_by_model(model_name)
        # Pick the concrete client class by (mode, provider); note Anthropic
        # only has a chat-mode client here.
        if mode == 'chat':
            if provider == ProviderName.OPENAI.value:
                llm_cls = StreamableChatOpenAI
            elif provider == ProviderName.AZURE_OPENAI.value:
                llm_cls = StreamableAzureChatOpenAI
            elif provider == ProviderName.ANTHROPIC.value:
                llm_cls = StreamableChatAnthropic
        elif mode == 'completion':
            if provider == ProviderName.OPENAI.value:
                llm_cls = StreamableOpenAI
            elif provider == ProviderName.AZURE_OPENAI.value:
                llm_cls = StreamableAzureOpenAI

        if not llm_cls:
            raise ValueError(f"model name {model_name} is not supported.")

        model_kwargs = {
            'model_name': model_name,
            'temperature': kwargs.get('temperature', 0),
            'max_tokens': kwargs.get('max_tokens', 256),
            'top_p': kwargs.get('top_p', 1),
            'frequency_penalty': kwargs.get('frequency_penalty', 0),
            'presence_penalty': kwargs.get('presence_penalty', 0),
            'callbacks': kwargs.get('callbacks', None),
            'streaming': kwargs.get('streaming', False),
        }

        model_kwargs.update(model_credentials)
        # Let the client class rename/drop params it does not support
        # (e.g. Anthropic has no frequency/presence penalties).
        model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)

        return llm_cls(**model_kwargs)

    @classmethod
    def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
                          callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
        """
        Build an LLM client from an app model config dict of the shape
        {'name': ..., 'completion_params': {...}}.
        """
        model_name = model.get("name")
        completion_params = model.get("completion_params", {})

        return cls.to_llm(
            tenant_id=tenant_id,
            model_name=model_name,
            temperature=completion_params.get('temperature', 0),
            # NOTE(review): defaults here (top_p=0, penalties=0.1) differ from
            # to_llm's own defaults (top_p=1, penalties=0) — confirm intended.
            top_p=completion_params.get('top_p', 0),
            max_tokens=completion_params.get('max_tokens', 256),
            frequency_penalty=completion_params.get('frequency_penalty', 0.1),
            presence_penalty=completion_params.get('presence_penalty', 0.1),
            streaming=streaming,
            callbacks=callbacks
        )

    @classmethod
    def get_mode_by_model(cls, model_name: str) -> str:
        """Return 'chat' or 'completion' for the model, per llm_constant tables."""
        if not model_name:
            raise ValueError(f"empty model name is not supported.")

        if model_name in llm_constant.models_by_mode['chat']:
            return "chat"
        elif model_name in llm_constant.models_by_mode['completion']:
            return "completion"
        else:
            raise ValueError(f"model name {model_name} is not supported.")

    @classmethod
    def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
        """
        Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
        Raises an exception if the model_name is not found or if the provider is not found.
        """
        if not model_name:
            raise Exception('model name not found')
        #
        # if model_name not in llm_constant.models:
        #     raise Exception('model {} not found'.format(model_name))
        # model_provider = llm_constant.models[model_name]

        provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
        return provider_service.get_credentials(model_name)

    @classmethod
    def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
        """
        Resolve which provider name should serve `model_name` for this tenant.

        For OpenAI-family models the preference order is:
        custom OpenAI > custom Azure > system OpenAI > system Azure.

        :raises ProviderTokenNotInitError: if no valid provider is configured
        """
        provider_name = llm_constant.models[model_name]

        if provider_name == 'openai':
            # get the default provider (openai / azure_openai) for the tenant
            openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
            azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)

            provider = None
            if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value:
                provider = openai_provider
            elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value:
                provider = azure_openai_provider
            elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value:
                provider = openai_provider
            elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value:
                provider = azure_openai_provider

            if not provider:
                raise ProviderTokenNotInitError(
                    f"No valid {provider_name} model provider credentials found. "
                    f"Please go to Settings -> Model Provider to complete your provider credentials."
                )

            provider_name = provider.provider_name

        return provider_name

View File

@ -1,15 +0,0 @@
import openai
from models.provider import ProviderName
class Moderation:
    """Thin wrapper around a provider's content-moderation endpoint."""

    def __init__(self, provider: str, api_key: str):
        """Bind the moderation client for the given provider name."""
        self.provider = provider
        self.api_key = api_key
        # Only OpenAI moderation is wired up; for other providers `client`
        # stays unset and moderate() would raise AttributeError.
        is_openai = self.provider == ProviderName.OPENAI.value
        if is_openai:
            self.client = openai.Moderation

    def moderate(self, text):
        """Run the moderation check on `text` and return the raw API response."""
        return self.client.create(input=text, api_key=self.api_key)

View File

@ -1,138 +0,0 @@
import json
import logging
from typing import Optional, Union
import anthropic
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName, ProviderType
class AnthropicProvider(BaseProvider):
    """Anthropic (Claude) credential provider backed by BaseProvider storage."""

    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        """Return the static list of supported Claude models."""
        return [
            {
                'id': 'claude-instant-1',
                'name': 'claude-instant-1',
            },
            {
                'id': 'claude-2',
                'name': 'claude-2',
            },
        ]

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """Return the decrypted credential dict for this tenant's Anthropic provider."""
        return self.get_provider_api_key(model_id=model_id)

    def get_provider_name(self):
        return ProviderName.ANTHROPIC

    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
        """
        Returns the provider configs.

        When `obfuscated` is True the API key is masked for display.
        """
        try:
            config = self.get_provider_api_key(only_custom=only_custom)
        except:
            # No credentials configured yet — fall back to an empty template.
            config = {
                'anthropic_api_key': ''
            }

        if obfuscated:
            if not config.get('anthropic_api_key'):
                config = {
                    'anthropic_api_key': ''
                }

            config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
            return config

        return config

    def get_encrypted_token(self, config: Union[dict | str]):
        """
        Returns the encrypted token (JSON with an RSA-encrypted api key).
        """
        return json.dumps({
            'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
        })

    def get_decrypted_token(self, token: str):
        """
        Returns the decrypted token.
        """
        config = json.loads(token)
        config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
        return config

    def get_token_type(self):
        # Credentials are stored as a JSON object, not a bare string.
        return dict

    def config_validate(self, config: Union[dict | str]):
        """
        Validates the given config.

        Anthropic can only be enabled once an OpenAI / Azure OpenAI provider
        is usable; the key itself is then checked with a minimal live request.
        :raises ValidateFailedError: on any validation failure
        """
        # check OpenAI / Azure OpenAI credential is valid
        openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
        azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)

        provider = None
        if openai_provider:
            provider = openai_provider
        elif azure_openai_provider:
            provider = azure_openai_provider

        if not provider:
            raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")

        if provider.provider_type == ProviderType.SYSTEM.value:
            quota_used = provider.quota_used if provider.quota_used is not None else 0
            quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
            if quota_used >= quota_limit:
                raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
                                          f"please configure OpenAI or Azure OpenAI provider first.")

        try:
            if not isinstance(config, dict):
                raise ValueError('Config must be a object.')

            if 'anthropic_api_key' not in config:
                raise ValueError('anthropic_api_key must be provided.')

            # Cheap live check: send a minimal one-message request with the
            # candidate key.
            chat_llm = ChatAnthropic(
                model='claude-instant-1',
                anthropic_api_key=config['anthropic_api_key'],
                max_tokens_to_sample=10,
                temperature=0,
                default_request_timeout=60
            )

            messages = [
                HumanMessage(
                    content="ping"
                )
            ]

            chat_llm(messages)
        except anthropic.APIConnectionError as ex:
            raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
        except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
            raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
                                      f"{ex.body['error']['type']}: {ex.body['error']['message']}")
        except Exception as ex:
            logging.exception('Anthropic config validation failed')
            raise ex

    def get_hosted_credentials(self) -> Union[str | dict]:
        """Return the platform-hosted Anthropic key, if hosting is configured."""
        if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
            raise ProviderTokenNotInitError(
                f"No valid {self.get_provider_name().value} model provider credentials found. "
                f"Please go to Settings -> Model Provider to complete your provider credentials."
            )

        return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}

View File

@ -1,145 +0,0 @@
import json
import logging
from typing import Optional, Union
import openai
import requests
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureProvider(BaseProvider):
    """Azure OpenAI credential provider; validated via the embedding deployment."""

    def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
        # Azure deployments are user-defined; there is no listable catalog here.
        return []

    def check_embedding_model(self, credentials: Optional[dict] = None):
        """
        Probe the tenant's 'text-embedding-ada-002' deployment to verify the
        Azure credentials work end-to-end.

        :raises AzureAuthenticationError: bad API key
        :raises AzureRequestFailedError: any other request failure
        """
        credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
        try:
            result = openai.Embedding.create(input=['test'],
                                             engine='text-embedding-ada-002',
                                             timeout=60,
                                             api_key=str(credentials.get('openai_api_key')),
                                             api_base=str(credentials.get('openai_api_base')),
                                             api_type='azure',
                                             api_version=str(credentials.get('openai_api_version')))["data"][0][
                "embedding"]
        except openai.error.AuthenticationError as e:
            raise AzureAuthenticationError(str(e))
        except openai.error.APIConnectionError as e:
            raise AzureRequestFailedError(
                'Failed to request Azure OpenAI, please check your API Base Endpoint, The format is `https://xxx.openai.azure.com/`')
        except openai.error.InvalidRequestError as e:
            # A 404 here means the deployment name does not exist on the resource.
            if e.http_status == 404:
                raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
                                              "deployment name is exists in Azure AI")
            else:
                raise AzureRequestFailedError(
                    'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
        except openai.error.OpenAIError as e:
            raise AzureRequestFailedError(
                'Failed to request Azure OpenAI. cause: {}'.format(str(e)))

        if not isinstance(result, list):
            raise AzureRequestFailedError('Failed to request Azure OpenAI.')

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """
        Returns the API credentials for Azure OpenAI as a dictionary.
        """
        config = self.get_provider_api_key(model_id=model_id)
        config['openai_api_type'] = 'azure'
        config['openai_api_version'] = AZURE_OPENAI_API_VERSION
        if model_id == 'text-embedding-ada-002':
            # The embedding client expects `deployment` + `chunk_size`;
            # chat/completion clients expect `deployment_name`.
            config['deployment'] = model_id.replace('.', '') if model_id else None
            config['chunk_size'] = 16
        else:
            config['deployment_name'] = model_id.replace('.', '') if model_id else None
        return config

    def get_provider_name(self):
        return ProviderName.AZURE_OPENAI

    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
        """
        Returns the provider configs.

        When `obfuscated` is True the API key is masked for display.
        """
        try:
            config = self.get_provider_api_key(only_custom=only_custom)
        except:
            # No credentials configured yet — return an empty template.
            config = {
                'openai_api_type': 'azure',
                'openai_api_version': AZURE_OPENAI_API_VERSION,
                'openai_api_base': '',
                'openai_api_key': ''
            }

        if obfuscated:
            if not config.get('openai_api_key'):
                config = {
                    'openai_api_type': 'azure',
                    'openai_api_version': AZURE_OPENAI_API_VERSION,
                    'openai_api_base': '',
                    'openai_api_key': ''
                }

            config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
            return config

        return config

    def get_token_type(self):
        # Credentials are stored as a JSON object, not a bare string.
        return dict

    def config_validate(self, config: Union[dict | str]):
        """
        Validates the given config.

        :raises ValidateFailedError: on any validation failure
        """
        try:
            if not isinstance(config, dict):
                raise ValueError('Config must be a object.')

            if 'openai_api_version' not in config:
                config['openai_api_version'] = AZURE_OPENAI_API_VERSION

            self.check_embedding_model(credentials=config)
        except ValidateFailedError as e:
            raise e
        except AzureAuthenticationError:
            raise ValidateFailedError('Validation failed, please check your API Key.')
        except AzureRequestFailedError as ex:
            raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
        except Exception as ex:
            logging.exception('Azure OpenAI Credentials validation failed')
            raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))

    def get_encrypted_token(self, config: Union[dict | str]):
        """
        Returns the encrypted token (JSON; only the api key is RSA-encrypted).
        """
        return json.dumps({
            'openai_api_type': 'azure',
            'openai_api_version': AZURE_OPENAI_API_VERSION,
            'openai_api_base': config['openai_api_base'],
            'openai_api_key': self.encrypt_token(config['openai_api_key'])
        })

    def get_decrypted_token(self, token: str):
        """
        Returns the decrypted token.
        """
        config = json.loads(token)
        config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
        return config
class AzureAuthenticationError(Exception):
    """Raised when Azure OpenAI rejects the API key."""
    pass
class AzureRequestFailedError(Exception):
    """Raised when an Azure OpenAI request fails for a non-authentication reason."""
    pass

View File

@ -1,132 +0,0 @@
import base64
from abc import ABC, abstractmethod
from typing import Optional, Union
from core.constant import llm_constant
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.provider import Provider, ProviderType, ProviderName
class BaseProvider(ABC):
    """Abstract base for model providers: credential storage, lookup and crypto."""

    def __init__(self, tenant_id: str):
        # All credential lookups below are scoped to this tenant.
        self.tenant_id = tenant_id

    def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
        """
        Returns the decrypted API key for the given tenant_id and provider_name.
        If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
        If the provider is not found or not valid, raises a ProviderTokenNotInitError.
        """
        provider = self.get_provider(only_custom)
        if not provider:
            raise ProviderTokenNotInitError(
                f"No valid {llm_constant.models[model_id]} model provider credentials found. "
                f"Please go to Settings -> Model Provider to complete your provider credentials."
            )

        if provider.provider_type == ProviderType.SYSTEM.value:
            quota_used = provider.quota_used if provider.quota_used is not None else 0
            quota_limit = provider.quota_limit if provider.quota_limit is not None else 0

            # Hosted (system) quota never covers GPT-4.
            if model_id and model_id == 'gpt-4':
                raise ModelCurrentlyNotSupportError()

            if quota_used >= quota_limit:
                raise QuotaExceededError()

            return self.get_hosted_credentials()
        else:
            return self.get_decrypted_token(provider.encrypted_config)

    def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
        """
        Returns the Provider instance for the given tenant_id and provider_name.
        If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
        """
        return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)

    @classmethod
    def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
        Provider]:
        """
        Returns the Provider instance for the given tenant_id and provider_name.
        If both CUSTOM and System providers exist.

        CUSTOM rows sort before SYSTEM rows, so a valid custom provider wins.
        """
        query = db.session.query(Provider).filter(
            Provider.tenant_id == tenant_id
        )

        if provider_name:
            query = query.filter(Provider.provider_name == provider_name)

        if only_custom:
            query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)

        providers = query.order_by(Provider.provider_type.asc()).all()

        for provider in providers:
            # Custom providers additionally need stored credentials to count as valid.
            if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
                return provider
            elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
                return provider

        return None

    def get_hosted_credentials(self) -> Union[str | dict]:
        """Hosted credentials; subclasses supporting platform hosting override this."""
        raise ProviderTokenNotInitError(
            f"No valid {self.get_provider_name().value} model provider credentials found. "
            f"Please go to Settings -> Model Provider to complete your provider credentials."
        )

    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
        """
        Returns the provider configs.
        """
        try:
            config = self.get_provider_api_key(only_custom=only_custom)
        except:
            config = ''

        if obfuscated:
            return self.obfuscated_token(config)

        return config

    def obfuscated_token(self, token: str):
        # Keep first 6 and last 2 chars visible; mask the middle.
        return token[:6] + '*' * (len(token) - 8) + token[-2:]

    def get_token_type(self):
        # Default credential storage type is a bare string token.
        return str

    def get_encrypted_token(self, config: Union[dict | str]):
        return self.encrypt_token(config)

    def get_decrypted_token(self, token: str):
        return self.decrypt_token(token)

    def encrypt_token(self, token):
        # Encrypt with the tenant's RSA public key, then base64-encode for storage.
        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
        return base64.b64encode(encrypted_token).decode()

    def decrypt_token(self, token):
        return rsa.decrypt(base64.b64decode(token), self.tenant_id)

    @abstractmethod
    def get_provider_name(self):
        raise NotImplementedError

    @abstractmethod
    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        raise NotImplementedError

    @abstractmethod
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        raise NotImplementedError

    @abstractmethod
    def config_validate(self, config: str):
        raise NotImplementedError

View File

@ -1,2 +0,0 @@
class ValidateFailedError(Exception):
    """Raised when a provider credential config fails validation."""
    description = "Provider Validate failed"

View File

@ -1,22 +0,0 @@
from typing import Optional
from core.llm.provider.base import BaseProvider
from models.provider import ProviderName
class HuggingfaceProvider(BaseProvider):
    """Hugging Face Hub credential provider (model listing not implemented yet)."""

    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        credentials = self.get_credentials(model_id)
        # todo
        return []

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """
        Returns the API credentials for Huggingface as a dictionary, for the given tenant_id.
        """
        return {
            'huggingface_api_key': self.get_provider_api_key(model_id=model_id)
        }

    def get_provider_name(self):
        return ProviderName.HUGGINGFACEHUB

View File

@ -1,53 +0,0 @@
from typing import Optional, Union
from core.llm.provider.anthropic_provider import AnthropicProvider
from core.llm.provider.azure_provider import AzureProvider
from core.llm.provider.base import BaseProvider
from core.llm.provider.huggingface_provider import HuggingfaceProvider
from core.llm.provider.openai_provider import OpenAIProvider
from models.provider import Provider
class LLMProviderService:
    """Facade that routes provider operations to the concrete provider impl."""

    def __init__(self, tenant_id: str, provider_name: str):
        self.provider = self.init_provider(tenant_id, provider_name)

    def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider:
        """Instantiate the provider implementation matching `provider_name`."""
        registry = {
            'openai': OpenAIProvider,
            'azure_openai': AzureProvider,
            'anthropic': AnthropicProvider,
            'huggingface': HuggingfaceProvider,
        }
        provider_cls = registry.get(provider_name)
        if provider_cls is None:
            raise Exception('provider {} not found'.format(provider_name))
        return provider_cls(tenant_id)

    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        """List models exposed by the underlying provider."""
        return self.provider.get_models(model_id)

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """Fetch decrypted credentials from the underlying provider."""
        return self.provider.get_credentials(model_id)

    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
        """Fetch (optionally obfuscated) provider configs."""
        return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)

    def get_provider_db_record(self) -> Optional[Provider]:
        """Return the backing Provider DB row, if any."""
        return self.provider.get_provider()

    def config_validate(self, config: Union[dict | str]):
        """
        Validates the given config.

        :param config:
        :raises: ValidateFailedError
        """
        return self.provider.config_validate(config)

    def get_token_type(self):
        """Return the storage type (str or dict) of the provider's token."""
        return self.provider.get_token_type()

    def get_encrypted_token(self, config: Union[dict | str]):
        """Encrypt the config for DB storage."""
        return self.provider.get_encrypted_token(config)

View File

@ -1,55 +0,0 @@
import logging
from typing import Optional, Union
import openai
from openai.error import AuthenticationError, OpenAIError
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.moderation import Moderation
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName
class OpenAIProvider(BaseProvider):
    """OpenAI credential provider; can fall back to platform-hosted keys."""

    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        """List models available to this tenant's OpenAI credentials."""
        credentials = self.get_credentials(model_id)
        response = openai.Model.list(**credentials)

        return [{
            'id': model['id'],
            'name': model['id'],
        } for model in response['data']]

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """
        Returns the credentials for the given tenant_id and provider_name.
        """
        return {
            'openai_api_key': self.get_provider_api_key(model_id=model_id)
        }

    def get_provider_name(self):
        return ProviderName.OPENAI

    def config_validate(self, config: Union[dict | str]):
        """
        Validates the given config.

        Validity is checked by running a moderation request with the key.
        :raises ValidateFailedError: if the key is rejected
        """
        try:
            Moderation(self.get_provider_name().value, config).moderate('test')
        except (AuthenticationError, OpenAIError) as ex:
            raise ValidateFailedError(str(ex))
        except Exception as ex:
            logging.exception('OpenAI config validation failed')
            raise ex

    def get_hosted_credentials(self) -> Union[str | dict]:
        """Return the platform-hosted OpenAI API key, if hosting is configured."""
        if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
            raise ProviderTokenNotInitError(
                f"No valid {self.get_provider_name().value} model provider credentials found. "
                f"Please go to Settings -> Model Provider to complete your provider credentials."
            )

        return hosted_llm_credentials.openai.api_key

View File

@ -1,62 +0,0 @@
from typing import List, Optional, Any, Dict
from httpx import Timeout
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
from pydantic import root_validator
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
class StreamableChatAnthropic(ChatAnthropic):
    """
    Wrapper around Anthropic's large language model.

    Adapts OpenAI-style model kwargs to Anthropic's parameter names and
    translates Anthropic SDK errors into the project's LLM error types.
    """

    default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0)

    @root_validator()
    def prepare_params(cls, values: Dict) -> Dict:
        # Mirror Anthropic-native fields onto the generic names the rest of
        # the codebase reads (`model_name`, `max_tokens`).
        values['model_name'] = values.get('model')
        values['max_tokens'] = values.get('max_tokens_to_sample')
        return values

    @handle_anthropic_exceptions
    def generate(
            self,
            messages: List[List[BaseMessage]],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            *,
            tags: Optional[List[str]] = None,
            metadata: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
    ) -> LLMResult:
        """Delegate to ChatAnthropic.generate with exception translation."""
        return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)

    @classmethod
    def get_kwargs_from_model_params(cls, params: dict):
        """
        Translate generic model params (OpenAI naming) into Anthropic kwargs.

        Renames `model_name` -> `model` and `max_tokens` -> `max_tokens_to_sample`,
        and drops the penalty params Anthropic does not support. Uses
        `dict.pop` with a default so a missing key no longer raises KeyError
        (the previous get-then-del pattern crashed when any key was absent).
        """
        params['model'] = params.pop('model_name', None)
        params['max_tokens_to_sample'] = params.pop('max_tokens', None)
        # Anthropic's API has no frequency/presence penalty knobs.
        params.pop('frequency_penalty', None)
        params.pop('presence_penalty', None)
        return params

    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
        """Render one chat message in Anthropic's prompt text format."""
        if isinstance(message, ChatMessage):
            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
        elif isinstance(message, HumanMessage):
            message_text = f"{self.HUMAN_PROMPT} {message.content}"
        elif isinstance(message, AIMessage):
            message_text = f"{self.AI_PROMPT} {message.content}"
        elif isinstance(message, SystemMessage):
            message_text = f"<admin>{message.content}</admin>"
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text

View File

@ -1,41 +0,0 @@
import decimal
from typing import Optional
import tiktoken
from core.constant import llm_constant
class TokenCalculator:
    """Token counting and pricing helpers backed by tiktoken and the price table."""

    @classmethod
    def get_num_tokens(cls, model_name: str, text: str):
        """Return how many tokens `text` encodes to under `model_name`'s encoding."""
        if not text:
            return 0

        encoding = tiktoken.encoding_for_model(model_name)
        # length of the encoded token sequence
        return len(encoding.encode(text))

    @classmethod
    def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal:
        """
        Price `tokens` tokens of the given model.

        Embedding models use their flat 'usage' rate; other models require
        text_type to be 'prompt' or 'completion'.
        """
        price_table = llm_constant.model_prices
        if model_name in llm_constant.models_by_mode['embedding']:
            unit_price = price_table[model_name]['usage']
        elif text_type in ('prompt', 'completion'):
            unit_price = price_table[model_name][text_type]
        else:
            raise Exception('Invalid text type')

        # Rate is per 1K tokens; round the 1K count to 3 places first,
        # then the final price to 7 places (half-up, matching billing).
        per_thousand = (decimal.Decimal(tokens) / 1000).quantize(
            decimal.Decimal('0.001'), rounding=decimal.ROUND_HALF_UP)
        total = per_thousand * unit_price
        return total.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    @classmethod
    def get_currency(cls, model_name: str):
        """Return the billing currency (shared across all models)."""
        return llm_constant.model_currency

View File

@ -1,26 +0,0 @@
import openai
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from models.provider import ProviderName
from core.llm.provider.base import BaseProvider
class Whisper:
    """Speech-to-text wrapper around OpenAI's whisper-1 endpoint."""

    def __init__(self, provider: BaseProvider):
        self.provider = provider
        # Only OpenAI is supported; for other providers client/credentials
        # stay unset and transcribe() would raise AttributeError.
        if self.provider.get_provider_name() == ProviderName.OPENAI:
            self.client = openai.Audio
            self.credentials = provider.get_credentials()

    @handle_openai_exceptions
    def transcribe(self, file):
        """Transcribe an audio `file` object and return the raw API response."""
        creds = self.credentials
        return self.client.transcribe(
            model='whisper-1',
            file=file,
            api_key=creds.get('openai_api_key'),
            api_base=creds.get('openai_api_base'),
            api_type=creds.get('openai_api_type'),
            api_version=creds.get('openai_api_version'),
        )

View File

@ -1,27 +0,0 @@
import logging
from functools import wraps
import anthropic
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
LLMBadRequestError
def handle_anthropic_exceptions(func):
    """
    Decorator mapping Anthropic SDK exceptions onto the project's LLM error
    hierarchy, so callers only need to handle one set of error types.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except anthropic.APIConnectionError as e:
            logging.exception("Failed to connect to Anthropic API.")
            raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
        except anthropic.RateLimitError:
            raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
        except anthropic.AuthenticationError as e:
            raise LLMAuthorizationError(f"Anthropic: {e.message}")
        except anthropic.BadRequestError as e:
            raise LLMBadRequestError(f"Anthropic: {e.message}")
        except anthropic.APIStatusError as e:
            # Catch-all for remaining non-2xx statuses; checked after the more
            # specific subclasses above.
            raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")

    return wrapper

View File

@ -1,31 +0,0 @@
import logging
from functools import wraps
import openai
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
LLMBadRequestError
def handle_openai_exceptions(func):
    """
    Decorator mapping OpenAI SDK exceptions onto the project's LLM error
    hierarchy, so callers only need to handle one set of error types.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except openai.error.InvalidRequestError as e:
            logging.exception("Invalid request to OpenAI API.")
            raise LLMBadRequestError(str(e))
        except openai.error.APIConnectionError as e:
            logging.exception("Failed to connect to OpenAI API.")
            raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
        except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
            logging.exception("OpenAI service unavailable.")
            raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
        except openai.error.RateLimitError as e:
            raise LLMRateLimitError(str(e))
        except openai.error.AuthenticationError as e:
            raise LLMAuthorizationError(str(e))
        except openai.error.OpenAIError as e:
            # Catch-all for remaining OpenAI errors; checked after the more
            # specific classes above.
            raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))

    return wrapper

View File

@ -1,10 +1,10 @@
from typing import Any, List, Dict, Union
from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
from langchain.schema import get_buffer_string, BaseMessage
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
from core.model_providers.models.llm.base import BaseLLM
from extensions.ext_database import db
from models.model import Conversation, Message
@ -13,7 +13,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
conversation: Conversation
human_prefix: str = "Human"
ai_prefix: str = "Assistant"
llm: BaseLanguageModel
model_instance: BaseLLM
memory_key: str = "chat_history"
max_token_limit: int = 2000
message_limit: int = 10
@ -29,23 +29,23 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
messages = list(reversed(messages))
chat_messages: List[BaseMessage] = []
chat_messages: List[PromptMessage] = []
for message in messages:
chat_messages.append(HumanMessage(content=message.query))
chat_messages.append(AIMessage(content=message.answer))
chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
if not chat_messages:
return chat_messages
return []
# prune the chat message if it exceeds the max token limit
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit and chat_messages:
pruned_memory.append(chat_messages.pop(0))
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
return chat_messages
return to_lc_messages(chat_messages)
@property
def memory_variables(self) -> List[str]:

View File

@ -0,0 +1,293 @@
from typing import Optional
from langchain.callbacks.base import Callbacks
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel
class ModelFactory:
    """Factory for tenant-scoped model instances and tenant default models.

    Resolves the provider (falling back to the tenant's default model when
    provider/model are omitted), instantiates the concrete model class, and
    reads/updates ``TenantDefaultModel`` records.
    """

    @classmethod
    def get_text_generation_model_from_model_config(cls, tenant_id: str,
                                                    model_config: dict,
                                                    streaming: bool = False,
                                                    callbacks: Callbacks = None) -> Optional[BaseLLM]:
        """
        get text generation model from an app model config dict.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_config: dict with 'provider', 'name' and optional 'completion_params'.
        :param streaming: whether the returned model streams tokens.
        :param callbacks: langchain callbacks.
        :return:
        """
        provider_name = model_config.get("provider")
        model_name = model_config.get("name")
        completion_params = model_config.get("completion_params", {})

        return cls.get_text_generation_model(
            tenant_id=tenant_id,
            model_provider_name=provider_name,
            model_name=model_name,
            model_kwargs=ModelKwargs(
                # defaults below apply when completion_params omits a key
                temperature=completion_params.get('temperature', 0),
                max_tokens=completion_params.get('max_tokens', 256),
                top_p=completion_params.get('top_p', 0),
                frequency_penalty=completion_params.get('frequency_penalty', 0.1),
                presence_penalty=completion_params.get('presence_penalty', 0.1)
            ),
            streaming=streaming,
            callbacks=callbacks
        )

    @classmethod
    def get_text_generation_model(cls,
                                  tenant_id: str,
                                  model_provider_name: Optional[str] = None,
                                  model_name: Optional[str] = None,
                                  model_kwargs: Optional[ModelKwargs] = None,
                                  streaming: bool = False,
                                  callbacks: Callbacks = None) -> Optional[BaseLLM]:
        """
        get text generation model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: provider name; None (with model_name None) means use the tenant default model.
        :param model_name: model name; None (with model_provider_name None) means use the tenant default model.
        :param model_kwargs: sampling/limit parameters for the model.
        :param streaming: whether the model streams tokens.
        :param callbacks: langchain callbacks.
        :return:
        :raises LLMBadRequestError: no default model configured, or the model is unavailable.
        :raises ProviderTokenNotInitError: provider credentials are missing.
        """
        is_default_model = False
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)

            if not default_model:
                raise LLMBadRequestError(f"Default model is not available. "
                                         f"Please configure a Default System Reasoning Model "
                                         f"in the Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name
            is_default_model = True

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

        # init text generation model
        model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)

        try:
            model_instance = model_class(
                model_provider=model_provider,
                name=model_name,
                model_kwargs=model_kwargs,
                streaming=streaming,
                callbacks=callbacks
            )
        except LLMBadRequestError as e:
            if is_default_model:
                # rewrap with a hint that the *default* model is misconfigured
                raise LLMBadRequestError(f"Default model {model_name} is not available. "
                                         f"Please check your model provider credentials.")
            else:
                raise e

        if is_default_model:
            # default models are not charged against the tenant quota
            model_instance.deduct_quota = False

        return model_instance

    @classmethod
    def get_embedding_model(cls,
                            tenant_id: str,
                            model_provider_name: Optional[str] = None,
                            model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
        """
        get embedding model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: provider name; None means use the tenant default embedding model.
        :param model_name: model name; None means use the tenant default embedding model.
        :return:
        :raises LLMBadRequestError: no default embedding model configured.
        :raises ProviderTokenNotInitError: provider credentials are missing.
        """
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)

            if not default_model:
                raise LLMBadRequestError(f"Default model is not available. "
                                         f"Please configure a Default Embedding Model "
                                         f"in the Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

        # init embedding model
        model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_speech2text_model(cls,
                              tenant_id: str,
                              model_provider_name: Optional[str] = None,
                              model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
        """
        get speech to text model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: provider name; None means use the tenant default speech-to-text model.
        :param model_name: model name; None means use the tenant default speech-to-text model.
        :return:
        :raises LLMBadRequestError: no default speech-to-text model configured.
        :raises ProviderTokenNotInitError: provider credentials are missing.
        """
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)

            if not default_model:
                raise LLMBadRequestError(f"Default model is not available. "
                                         f"Please configure a Default Speech-to-Text Model "
                                         f"in the Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

        # init speech to text model
        model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_moderation_model(cls,
                             tenant_id: str,
                             model_provider_name: str,
                             model_name: str) -> Optional[BaseProviderModel]:
        """
        get moderation model.

        Note: unlike the other getters, provider and model are required here
        (no default-model fallback).

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :param model_name:
        :return:
        :raises ProviderTokenNotInitError: provider credentials are missing.
        """
        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

        # init moderation model
        model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
        """
        get default model of model type.

        When no record exists yet, picks the first supported model of the
        first provider (in rule order) with valid credentials and persists
        it as the tenant default. May return None if no provider qualifies.

        :param tenant_id:
        :param model_type:
        :return:
        """
        # get default model
        default_model = db.session.query(TenantDefaultModel) \
            .filter(
            TenantDefaultModel.tenant_id == tenant_id,
            TenantDefaultModel.model_type == model_type.value
        ).first()

        if not default_model:
            model_provider_rules = ModelProviderFactory.get_provider_rules()
            for model_provider_name, model_provider_rule in model_provider_rules.items():
                model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
                if not model_provider:
                    continue

                model_list = model_provider.get_supported_model_list(model_type)
                if model_list:
                    model_info = model_list[0]
                    default_model = TenantDefaultModel(
                        tenant_id=tenant_id,
                        model_type=model_type.value,
                        provider_name=model_provider_name,
                        model_name=model_info['id']
                    )
                    db.session.add(default_model)
                    db.session.commit()
                    break

        return default_model

    @classmethod
    def update_default_model(cls,
                             tenant_id: str,
                             model_type: ModelType,
                             provider_name: str,
                             model_name: str) -> TenantDefaultModel:
        """
        update default model of model type.

        :param tenant_id:
        :param model_type:
        :param provider_name:
        :param model_name:
        :return:
        :raises ValueError: unknown provider name or model name.
        :raises ProviderTokenNotInitError: provider credentials are missing.
        """
        model_provider_name = ModelProviderFactory.get_provider_names()
        if provider_name not in model_provider_name:
            raise ValueError(f'Invalid provider name: {provider_name}')

        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

        # validate the model name against the provider's supported list
        model_list = model_provider.get_supported_model_list(model_type)
        model_ids = [model['id'] for model in model_list]
        if model_name not in model_ids:
            raise ValueError(f'Invalid model name: {model_name}')

        # get default model
        default_model = db.session.query(TenantDefaultModel) \
            .filter(
            TenantDefaultModel.tenant_id == tenant_id,
            TenantDefaultModel.model_type == model_type.value
        ).first()

        if default_model:
            # update default model
            default_model.provider_name = provider_name
            default_model.model_name = model_name
            db.session.commit()
        else:
            # create default model
            default_model = TenantDefaultModel(
                tenant_id=tenant_id,
                model_type=model_type.value,
                provider_name=provider_name,
                model_name=model_name,
            )
            db.session.add(default_model)
            db.session.commit()

        return default_model

View File

@ -0,0 +1,228 @@
from typing import Type
from sqlalchemy.exc import IntegrityError
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.rules import provider_rules
from extensions.ext_database import db
from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType
# Fallback provider/model per model type, used when a tenant has not
# configured a default model of that type.
DEFAULT_MODELS = {
    ModelType.TEXT_GENERATION.value: {
        'provider_name': 'openai',
        'model_name': 'gpt-3.5-turbo',
    },
    ModelType.EMBEDDINGS.value: {
        'provider_name': 'openai',
        'model_name': 'text-embedding-ada-002',
    },
    ModelType.SPEECH_TO_TEXT.value: {
        'provider_name': 'openai',
        'model_name': 'whisper-1',
    }
}
class ModelProviderFactory:
    """Resolves provider classes and a tenant's preferred provider record
    (system-quota vs custom-credential) for a given provider name.
    """

    @classmethod
    def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
        """Return the provider class for *provider_name*.

        Provider modules are imported lazily to avoid pulling in every
        provider SDK at startup.

        :raises NotImplementedError: for an unknown provider name.
        """
        if provider_name == 'openai':
            from core.model_providers.providers.openai_provider import OpenAIProvider
            return OpenAIProvider
        elif provider_name == 'anthropic':
            from core.model_providers.providers.anthropic_provider import AnthropicProvider
            return AnthropicProvider
        elif provider_name == 'minimax':
            from core.model_providers.providers.minimax_provider import MinimaxProvider
            return MinimaxProvider
        elif provider_name == 'spark':
            from core.model_providers.providers.spark_provider import SparkProvider
            return SparkProvider
        elif provider_name == 'tongyi':
            from core.model_providers.providers.tongyi_provider import TongyiProvider
            return TongyiProvider
        elif provider_name == 'wenxin':
            from core.model_providers.providers.wenxin_provider import WenxinProvider
            return WenxinProvider
        elif provider_name == 'chatglm':
            from core.model_providers.providers.chatglm_provider import ChatGLMProvider
            return ChatGLMProvider
        elif provider_name == 'azure_openai':
            from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
            return AzureOpenAIProvider
        elif provider_name == 'replicate':
            from core.model_providers.providers.replicate_provider import ReplicateProvider
            return ReplicateProvider
        elif provider_name == 'huggingface_hub':
            from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
            return HuggingfaceHubProvider
        else:
            raise NotImplementedError

    @classmethod
    def get_provider_names(cls):
        """
        Returns a list of provider names.
        """
        return list(provider_rules.keys())

    @classmethod
    def get_provider_rules(cls):
        """
        Returns a list of provider rules.

        :return:
        """
        return provider_rules

    @classmethod
    def get_provider_rule(cls, provider_name: str):
        """
        Returns provider rule.

        NOTE: raises KeyError for an unknown provider name.
        """
        return provider_rules[provider_name]

    @classmethod
    def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
        """
        get preferred model provider.

        Returns None when the tenant has no valid provider record for the
        given provider name.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :return:
        """
        # get preferred provider
        preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
        if not preferred_provider or not preferred_provider.is_valid:
            return None

        # init model provider
        model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
        return model_provider_class(provider=preferred_provider)

    @classmethod
    def get_preferred_type_by_preferred_model_provider(cls,
                                                       tenant_id: str,
                                                       model_provider_name: str,
                                                       preferred_model_provider: TenantPreferredModelProvider):
        """
        get preferred provider type by preferred model provider.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :param preferred_model_provider: existing preference record, or None.
        :return:
        """
        if not preferred_model_provider:
            # no explicit preference: derive one from the provider rules
            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
            support_provider_types = model_provider_rules['support_provider_types']

            if ProviderType.CUSTOM.value in support_provider_types:
                # prefer an already-configured, valid custom provider
                custom_provider = db.session.query(Provider) \
                    .filter(
                    Provider.tenant_id == tenant_id,
                    Provider.provider_name == model_provider_name,
                    Provider.provider_type == ProviderType.CUSTOM.value,
                    Provider.is_valid == True
                ).first()

                if custom_provider:
                    return ProviderType.CUSTOM.value

            model_provider = cls.get_model_provider_class(model_provider_name)

            if ProviderType.SYSTEM.value in support_provider_types \
                    and model_provider.is_provider_type_system_supported():
                return ProviderType.SYSTEM.value
            elif ProviderType.CUSTOM.value in support_provider_types:
                return ProviderType.CUSTOM.value
            # NOTE(review): falls through returning None when neither type is
            # supported — confirm callers tolerate a None provider type
        else:
            return preferred_model_provider.preferred_provider_type

    @classmethod
    def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
        """
        get preferred provider of tenant.

        For SYSTEM preference: return the first valid provider (in
        ProviderQuotaType order) with remaining quota; otherwise fall back to
        the CUSTOM provider record, creating an invalid placeholder row if
        none exists yet.

        :param tenant_id:
        :param model_provider_name:
        :return:
        """
        # get preferred provider type
        preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)

        # get providers by preferred provider type
        providers = db.session.query(Provider) \
            .filter(
            Provider.tenant_id == tenant_id,
            Provider.provider_name == model_provider_name,
            Provider.provider_type == preferred_provider_type
        ).all()

        no_system_provider = False
        if preferred_provider_type == ProviderType.SYSTEM.value:
            quota_type_to_provider_dict = {}
            for provider in providers:
                quota_type_to_provider_dict[provider.quota_type] = provider

            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
            for quota_type_enum in ProviderQuotaType:
                quota_type = quota_type_enum.value
                if quota_type in model_provider_rules['system_config']['supported_quota_types'] \
                        and quota_type in quota_type_to_provider_dict.keys():
                    provider = quota_type_to_provider_dict[quota_type]
                    if provider.is_valid and provider.quota_limit > provider.quota_used:
                        return provider

            # no usable system provider: fall back to custom below
            no_system_provider = True

        if no_system_provider:
            providers = db.session.query(Provider) \
                .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == model_provider_name,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).all()

        if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
            if providers:
                return providers[0]
            else:
                # create an invalid placeholder row; IntegrityError means a
                # concurrent request created it first, so re-read it
                try:
                    provider = Provider(
                        tenant_id=tenant_id,
                        provider_name=model_provider_name,
                        provider_type=ProviderType.CUSTOM.value,
                        is_valid=False
                    )
                    db.session.add(provider)
                    db.session.commit()
                except IntegrityError:
                    db.session.rollback()
                    provider = db.session.query(Provider) \
                        .filter(
                        Provider.tenant_id == tenant_id,
                        Provider.provider_name == model_provider_name,
                        Provider.provider_type == ProviderType.CUSTOM.value
                    ).first()

                return provider

        return None

    @classmethod
    def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
        """
        get preferred provider type of tenant.

        :param tenant_id:
        :param model_provider_name:
        :return:
        """
        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
            .filter(
            TenantPreferredModelProvider.tenant_id == tenant_id,
            TenantPreferredModelProvider.provider_name == model_provider_name
        ).first()

        return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)

View File

@ -0,0 +1,22 @@
from abc import ABC
from typing import Any
from core.model_providers.providers.base import BaseModelProvider
class BaseProviderModel(ABC):
    """Abstract base for every provider-backed model wrapper.

    Stores the underlying client object and the model provider that supplied
    its credentials, exposing both as read-only properties.
    """

    _client: Any
    _model_provider: BaseModelProvider

    def __init__(self, model_provider: BaseModelProvider, client: Any):
        self._client = client
        self._model_provider = model_provider

    @property
    def client(self):
        """The wrapped low-level client instance."""
        return self._client

    @property
    def model_provider(self):
        """The provider this model was created from."""
        return self._model_provider

View File

@ -0,0 +1,78 @@
import decimal
import logging
import openai
import tiktoken
from langchain.embeddings import OpenAIEmbeddings
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMRateLimitError, \
LLMAPIUnavailableError, LLMAPIConnectionError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureOpenAIEmbedding(BaseEmbedding):
    """Embedding model served by an Azure OpenAI deployment."""

    def __init__(self, model_provider: BaseModelProvider, name: str):
        # kept on self: get_num_tokens needs 'base_model_name' later
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = OpenAIEmbeddings(
            deployment=name,
            openai_api_type='azure',
            openai_api_version=AZURE_OPENAI_API_VERSION,
            chunk_size=16,
            max_retries=1,
            **self.credentials
        )

        super().__init__(model_provider, client, name)

    def get_num_tokens(self, text: str) -> int:
        """
        get num tokens of text.

        :param text:
        :return:
        """
        if len(text) == 0:
            return 0

        # deployment names are arbitrary; token counting must use the
        # underlying base model's encoding from the credentials
        enc = tiktoken.encoding_for_model(self.credentials.get('base_model_name'))

        tokenized_text = enc.encode(text)

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

    def get_token_price(self, tokens: int):
        """Return the USD price for *tokens* at $0.0001 per 1K tokens."""
        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)
        total_price = tokens_per_1k * decimal.Decimal('0.0001')
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        return 'USD'

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Map OpenAI SDK exceptions to app-level LLM errors.

        Always *returns* the mapped exception (never raises) so callers can
        uniformly do ``raise self.handle_exceptions(ex)``.
        """
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to Azure OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to Azure OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("Azure OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError('Azure ' + str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            # fixed: previously `raise`d instead of returning, breaking the
            # mapper's "always return" contract used by every other branch
            return LLMAuthorizationError('Azure ' + str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex

View File

@ -0,0 +1,40 @@
from abc import abstractmethod
from typing import Any
import tiktoken
from langchain.schema.language_model import _get_token_ids_default_method
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
class BaseEmbedding(BaseProviderModel):
    """Common base class for embedding model wrappers."""

    name: str
    type: ModelType = ModelType.EMBEDDINGS

    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
        super().__init__(model_provider, client)
        self.name = name

    def get_num_tokens(self, text: str) -> int:
        """Count tokens in *text* via langchain's default tokenizer."""
        if not text:
            return 0
        token_ids = _get_token_ids_default_method(text)
        return len(token_ids)

    def get_token_price(self, tokens: int):
        """Default price: free. Billed providers override this."""
        return 0

    def get_currency(self):
        """Default billing currency for embeddings."""
        return 'USD'

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        """Translate a provider SDK exception into an app-level error."""
        raise NotImplementedError

View File

@ -0,0 +1,35 @@
import decimal
import logging
from langchain.embeddings import MiniMaxEmbeddings
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
class MinimaxEmbedding(BaseEmbedding):
    """Embedding model backed by MiniMax."""

    def __init__(self, model_provider: BaseModelProvider, name: str):
        creds = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        embeddings_client = MiniMaxEmbeddings(model=name, **creds)
        super().__init__(model_provider, embeddings_client, name)

    def get_token_price(self, tokens: int):
        """MiniMax embedding usage is not billed per token here."""
        return decimal.Decimal('0')

    def get_currency(self):
        return 'RMB'

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Wrap MiniMax client ValueErrors; pass anything else through."""
        if not isinstance(ex, ValueError):
            return ex
        return LLMBadRequestError(f"Minimax: {str(ex)}")

View File

@ -0,0 +1,72 @@
import decimal
import logging
import openai
import tiktoken
from langchain.embeddings import OpenAIEmbeddings
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
class OpenAIEmbedding(BaseEmbedding):
    """Embedding model served by the OpenAI API."""

    def __init__(self, model_provider: BaseModelProvider, name: str):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = OpenAIEmbeddings(
            max_retries=1,
            **credentials
        )

        super().__init__(model_provider, client, name)

    def get_num_tokens(self, text: str) -> int:
        """
        get num tokens of text.

        :param text:
        :return:
        """
        if len(text) == 0:
            return 0

        enc = tiktoken.encoding_for_model(self.name)

        tokenized_text = enc.encode(text)

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

    def get_token_price(self, tokens: int):
        """Return the USD price for *tokens* at $0.0001 per 1K tokens."""
        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)
        total_price = tokens_per_1k * decimal.Decimal('0.0001')
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        return 'USD'

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Map OpenAI SDK exceptions to app-level LLM errors.

        Always *returns* the mapped exception (never raises) so callers can
        uniformly do ``raise self.handle_exceptions(ex)``.
        """
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            # fixed: previously `raise`d instead of returning, breaking the
            # mapper's "always return" contract used by every other branch
            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex

View File

@ -0,0 +1,36 @@
import decimal
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.embeddings.replicate_embedding import ReplicateEmbeddings
from core.model_providers.models.embedding.base import BaseEmbedding
class ReplicateEmbedding(BaseEmbedding):
    """Embedding model hosted on Replicate."""

    def __init__(self, model_provider: BaseModelProvider, name: str):
        creds = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        # Replicate addresses models as "<name>:<version>"
        versioned_model = name + ':' + creds.get('model_version')
        embeddings_client = ReplicateEmbeddings(
            model=versioned_model,
            replicate_api_token=creds.get('replicate_api_token')
        )
        super().__init__(model_provider, embeddings_client, name)

    def get_token_price(self, tokens: int):
        # replicate only pay for prediction seconds
        return decimal.Decimal('0')

    def get_currency(self):
        return 'USD'

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Wrap Replicate client errors; pass anything else through."""
        if not isinstance(ex, (ModelError, ReplicateError)):
            return ex
        return LLMBadRequestError(f"Replicate: {str(ex)}")

View File

@ -0,0 +1,53 @@
import enum
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from pydantic import BaseModel
class LLMRunResult(BaseModel):
    """Outcome of one LLM generation call, including token usage."""
    content: str             # generated text
    prompt_tokens: int       # tokens consumed by the prompt
    completion_tokens: int   # tokens produced in the completion
class MessageType(enum.Enum):
    """Role of a message within a conversation."""
    HUMAN = 'human'
    ASSISTANT = 'assistant'
    SYSTEM = 'system'
class PromptMessage(BaseModel):
    """A single provider-agnostic chat message."""
    type: MessageType = MessageType.HUMAN   # sender role; defaults to the user
    content: str = ''
def to_lc_messages(messages: list[PromptMessage]):
    """Convert internal PromptMessages to langchain message objects.

    Messages with an unrecognized type are silently skipped.
    """
    type_to_cls = {
        MessageType.HUMAN: HumanMessage,
        MessageType.ASSISTANT: AIMessage,
        MessageType.SYSTEM: SystemMessage,
    }
    converted = []
    for message in messages:
        lc_cls = type_to_cls.get(message.type)
        if lc_cls is not None:
            converted.append(lc_cls(content=message.content))
    return converted
def to_prompt_messages(messages: list[BaseMessage]):
    """Convert langchain messages back into internal PromptMessages.

    Checks are ordered as the original human/assistant/system chain;
    unrecognized message classes are silently skipped.
    """
    class_type_pairs = (
        (HumanMessage, MessageType.HUMAN),
        (AIMessage, MessageType.ASSISTANT),
        (SystemMessage, MessageType.SYSTEM),
    )
    prompt_messages = []
    for message in messages:
        for lc_cls, message_type in class_type_pairs:
            if isinstance(message, lc_cls):
                prompt_messages.append(PromptMessage(content=message.content, type=message_type))
                break
    return prompt_messages
def str_to_prompt_messages(texts: list[str]):
    """Wrap each raw text as a human-role PromptMessage."""
    return [PromptMessage(content=text) for text in texts]

View File

@ -0,0 +1,59 @@
import enum
from typing import Optional, TypeVar, Generic
from langchain.load.serializable import Serializable
from pydantic import BaseModel
class ModelMode(enum.Enum):
    """Interaction style of a text-generation model."""
    COMPLETION = 'completion'   # single-prompt completion API
    CHAT = 'chat'               # message-based chat API
class ModelType(enum.Enum):
    """Kinds of models a provider may expose."""

    TEXT_GENERATION = 'text-generation'
    EMBEDDINGS = 'embeddings'
    SPEECH_TO_TEXT = 'speech2text'
    IMAGE = 'image'
    VIDEO = 'video'
    MODERATION = 'moderation'

    @staticmethod
    def value_of(value):
        """Return the member whose value equals *value*.

        Raises ValueError when no member matches.
        """
        matched = next((member for member in ModelType if member.value == value), None)
        if matched is None:
            raise ValueError(f"No matching enum found for value '{value}'")
        return matched
class ModelKwargs(BaseModel):
    """Sampling/limit parameters passed through to the provider client.

    Any field may be None, meaning "use the provider's default".
    """
    max_tokens: Optional[int]
    temperature: Optional[float]
    top_p: Optional[float]
    presence_penalty: Optional[float]
    frequency_penalty: Optional[float]
class KwargRuleType(enum.Enum):
    """Value type of a model kwarg rule."""
    STRING = 'string'
    INTEGER = 'integer'
    FLOAT = 'float'
# value type of a KwargRule (int or float in practice)
T = TypeVar('T')


class KwargRule(Generic[T], BaseModel):
    """Validation rule for a single model kwarg."""
    enabled: bool = True        # whether the provider supports this kwarg
    min: Optional[T] = None     # inclusive lower bound, if any
    max: Optional[T] = None     # inclusive upper bound, if any
    default: Optional[T] = None
    alias: Optional[str] = None  # provider-side parameter name, if it differs
class ModelKwargsRules(BaseModel):
    """Per-kwarg rules for a model; all disabled unless a provider opts in."""
    max_tokens: KwargRule = KwargRule[int](enabled=False)
    temperature: KwargRule = KwargRule[float](enabled=False)
    top_p: KwargRule = KwargRule[float](enabled=False)
    presence_penalty: KwargRule = KwargRule[float](enabled=False)
    frequency_penalty: KwargRule = KwargRule[float](enabled=False)

View File

@ -0,0 +1,10 @@
from enum import Enum
class ProviderQuotaUnit(Enum):
    """Unit in which a hosted provider's quota is counted."""
    TIMES = 'times'     # number of calls
    TOKENS = 'tokens'   # number of tokens consumed
class ModelFeature(Enum):
    """Optional capabilities a model may advertise."""
    AGENT_THOUGHT = 'agent_thought'

View File

@ -0,0 +1,107 @@
import decimal
import logging
from functools import wraps
from typing import List, Optional, Any
import anthropic
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class AnthropicModel(BaseLLM):
    """Text generation model backed by Anthropic's Claude chat API."""

    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        """Build the langchain ChatAnthropic client from credentials and kwargs."""
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return ChatAnthropic(
            model=self.name,
            streaming=self.streaming,
            callbacks=self.callbacks,
            default_request_timeout=60,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        # subtract one token per message from the client's count, clamped at 0
        return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        """Price *tokens* in USD using Anthropic's per-million-token rates.

        NOTE(review): keys must exactly match self.name; an unlisted model
        name raises KeyError — confirm this stays in sync with the provider's
        supported model list.
        """
        model_unit_prices = {
            'claude-instant-1': {
                'prompt': decimal.Decimal('1.63'),
                'completion': decimal.Decimal('5.51'),
            },
            'claude-2': {
                'prompt': decimal.Decimal('11.02'),
                'completion': decimal.Decimal('32.68'),
            },
        }

        # human and system messages are billed at the prompt-side rate
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = model_unit_prices[self.name]['prompt']
        else:
            unit_price = model_unit_prices[self.name]['completion']

        tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
                                                                     rounding=decimal.ROUND_HALF_UP)
        total_price = tokens_per_1m * unit_price
        return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        """Push updated kwargs onto the existing client object in place."""
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Map anthropic SDK exceptions to app-level LLM errors (returned, not raised)."""
        if isinstance(ex, anthropic.APIConnectionError):
            logging.warning("Failed to connect to Anthropic API.")
            return LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {ex.__cause__}")
        elif isinstance(ex, anthropic.RateLimitError):
            return LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
        elif isinstance(ex, anthropic.AuthenticationError):
            return LLMAuthorizationError(f"Anthropic: {ex.message}")
        elif isinstance(ex, anthropic.BadRequestError):
            return LLMBadRequestError(f"Anthropic: {ex.message}")
        elif isinstance(ex, anthropic.APIStatusError):
            return LLMAPIUnavailableError(f"Anthropic: code: {ex.status_code}, cause: {ex.message}")
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True

View File

@ -0,0 +1,177 @@
import decimal
import logging
from functools import wraps
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI
from core.third_party.langchain.llms.azure_open_ai import EnhanceAzureOpenAI
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureOpenAIModel(BaseLLM):
    """
    LLM wrapper for Azure OpenAI deployments.

    Chooses between the legacy completion client and the chat client based on
    the deployment name, and prices tokens by the deployment's base model.
    """

    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        # 'text-davinci-003' is the only completion-mode deployment handled
        # here; every other deployment goes through the chat API.
        if name == 'text-davinci-003':
            self.model_mode = ModelMode.COMPLETION
        else:
            self.model_mode = ModelMode.CHAT

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
        """Build the langchain client for this deployment (completion or chat)."""
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        if self.name == 'text-davinci-003':
            client = EnhanceAzureOpenAI(
                deployment_name=self.name,
                streaming=self.streaming,
                request_timeout=60,
                openai_api_type='azure',
                openai_api_version=AZURE_OPENAI_API_VERSION,
                openai_api_key=self.credentials.get('openai_api_key'),
                openai_api_base=self.credentials.get('openai_api_base'),
                callbacks=self.callbacks,
                **provider_model_kwargs
            )
        else:
            # The chat client only takes these sampling params via model_kwargs;
            # temperature/max_tokens are first-class constructor args.
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p'),
                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
            }

            client = EnhanceAzureChatOpenAI(
                deployment_name=self.name,
                temperature=provider_model_kwargs.get('temperature'),
                max_tokens=provider_model_kwargs.get('max_tokens'),
                model_kwargs=extra_model_kwargs,
                streaming=self.streaming,
                request_timeout=60,
                openai_api_type='azure',
                openai_api_version=AZURE_OPENAI_API_VERSION,
                openai_api_key=self.credentials.get('openai_api_key'),
                openai_api_base=self.credentials.get('openai_api_base'),
                callbacks=self.callbacks,
            )

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages: prompt messages to send
        :param stop: stop words
        :param callbacks: callbacks for this generation
        :return: raw langchain LLMResult
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages: prompt messages
        :return: token count (never negative)
        """
        prompts = self._get_prompt_from_messages(messages)
        if isinstance(prompts, str):
            return self._client.get_num_tokens(prompts)
        else:
            # Subtract one token per message to offset per-message framing,
            # clamped at zero.
            return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        """
        Return the USD price for `tokens` of the given message type.

        Unit prices are USD per 1K tokens keyed by the deployment's base model
        name (taken from credentials).
        """
        model_unit_prices = {
            'gpt-4': {
                'prompt': decimal.Decimal('0.03'),
                'completion': decimal.Decimal('0.06'),
            },
            'gpt-4-32k': {
                'prompt': decimal.Decimal('0.06'),
                'completion': decimal.Decimal('0.12')
            },
            'gpt-35-turbo': {
                'prompt': decimal.Decimal('0.0015'),
                'completion': decimal.Decimal('0.002')
            },
            'gpt-35-turbo-16k': {
                'prompt': decimal.Decimal('0.003'),
                'completion': decimal.Decimal('0.004')
            },
            'text-davinci-003': {
                'prompt': decimal.Decimal('0.02'),
                'completion': decimal.Decimal('0.02')
            },
        }

        base_model_name = self.credentials.get("base_model_name")
        # Human/system messages are billed at prompt rate; assistant at completion rate.
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = model_unit_prices[base_model_name]['prompt']
        else:
            unit_price = model_unit_prices[base_model_name]['completion']

        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)
        total_price = tokens_per_1k * unit_price
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        """Prices returned by get_token_price are in US dollars."""
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        """Apply new kwargs to the already-constructed client in place."""
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        if self.name == 'text-davinci-003':
            # The completion client exposes each sampling param as an attribute.
            for k, v in provider_model_kwargs.items():
                if hasattr(self.client, k):
                    setattr(self.client, k, v)
        else:
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p'),
                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
            }
            self.client.temperature = provider_model_kwargs.get('temperature')
            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
            self.client.model_kwargs = extra_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Map openai SDK exceptions to the app's LLM error types.

        Always returns the mapped exception; callers raise it themselves.
        """
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to Azure OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to Azure OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("Azure OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError('Azure ' + str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            # Fix: this branch previously `raise`d instead of `return`ing,
            # unlike every other branch and the declared -> Exception contract.
            return LLMAuthorizationError('Azure ' + str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        """Azure OpenAI supports native token streaming."""
        return True

View File

@ -0,0 +1,269 @@
from abc import abstractmethod
from typing import List, Optional, Any, Union
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.fake import FakeLLM
class BaseLLM(BaseProviderModel):
    """
    Abstract base for provider-specific text-generation model wrappers.

    Centralizes the shared lifecycle: resolving parameter rules and
    credentials from the provider, installing default stdout callbacks,
    quota checking/deduction, simulated streaming for backends that cannot
    stream natively, and token/usage accounting around each run.
    """

    # Completion vs chat prompting; subclasses override per model family.
    model_mode: ModelMode = ModelMode.COMPLETION
    name: str
    model_kwargs: ModelKwargs
    credentials: dict
    streaming: bool = False
    type: ModelType = ModelType.TEXT_GENERATION
    # When True, run() enforces the provider quota and deducts usage from it.
    deduct_quota: bool = True

    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        """
        :param model_provider: provider used for rules, credentials and quota
        :param name: model name within the provider
        :param model_kwargs: sampling parameters (all-None defaults if falsy)
        :param streaming: whether to stream tokens
        :param callbacks: extra callbacks; a stdout handler is always appended
        """
        self.name = name
        self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
        # Fall back to an all-None kwargs object so rule defaults apply later
        # in _to_model_kwargs_input.
        self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
            max_tokens=None,
            temperature=None,
            top_p=None,
            presence_penalty=None,
            frequency_penalty=None
        )
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        self.streaming = streaming

        if streaming:
            default_callback = DifyStreamingStdOutCallbackHandler()
        else:
            default_callback = DifyStdOutCallbackHandler()

        if not callbacks:
            callbacks = [default_callback]
        else:
            callbacks.append(default_callback)

        self.callbacks = callbacks

        client = self._init_client()
        super().__init__(model_provider, client)

    @abstractmethod
    def _init_client(self) -> Any:
        """Construct and return the underlying (langchain) client."""
        raise NotImplementedError

    def run(self, messages: List[PromptMessage],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs) -> LLMRunResult:
        """
        run predict by prompt messages and stop words.

        :param messages: prompt messages to send
        :param stop: stop words
        :param callbacks: extra callbacks, merged with the instance defaults
        :return: completion content plus prompt/completion token counts
        """
        if self.deduct_quota:
            self.model_provider.check_quota_over_limit()

        if not callbacks:
            callbacks = self.callbacks
        else:
            callbacks.extend(self.callbacks)

        if 'fake_response' in kwargs and kwargs['fake_response']:
            # Short-circuit with a canned response, still driven through
            # FakeLLM so streaming callbacks fire as usual.
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=kwargs['fake_response'],
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            result = fake_llm.generate([prompts])
        else:
            try:
                # Suppress callbacks here when streaming is requested but the
                # model can't stream; they are replayed by FakeLLM below.
                result = self._run(
                    messages=messages,
                    stop=stop,
                    callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
                    **kwargs
                )
            except Exception as ex:
                raise self.handle_exceptions(ex)

        if isinstance(result.generations[0][0], ChatGeneration):
            completion_content = result.generations[0][0].message.content
        else:
            completion_content = result.generations[0][0].text

        if self.streaming and not self.support_streaming():
            # use FakeLLM to simulate streaming when current model not support streaming but streaming is True
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=completion_content,
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            fake_llm.generate([prompts])

        # NOTE(review): assumes llm_output, when truthy, always carries a
        # 'token_usage' key — a provider returning llm_output without it
        # would raise KeyError here; confirm against all _run implementations.
        if result.llm_output and result.llm_output['token_usage']:
            prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
            completion_tokens = result.llm_output['token_usage']['completion_tokens']
            total_tokens = result.llm_output['token_usage']['total_tokens']
        else:
            # Backend didn't report usage; estimate with the model tokenizer.
            prompt_tokens = self.get_num_tokens(messages)
            completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
            total_tokens = prompt_tokens + completion_tokens

        if self.deduct_quota:
            self.model_provider.deduct_quota(total_tokens)

        return LLMRunResult(
            content=completion_content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens
        )

    @abstractmethod
    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_token_price(self, tokens: int, message_type: MessageType):
        """
        get token price.

        :param tokens:
        :param message_type: prompt or completion pricing is selected by type
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_currency(self):
        """
        get token currency.

        :return:
        """
        raise NotImplementedError

    def get_model_kwargs(self):
        """Return the current sampling parameters."""
        return self.model_kwargs

    def set_model_kwargs(self, model_kwargs: ModelKwargs):
        """Replace the sampling parameters and propagate them to the client."""
        self.model_kwargs = model_kwargs
        self._set_model_kwargs(model_kwargs)

    @abstractmethod
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        """Apply new sampling parameters to the already-built client."""
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Handle llm run exceptions.

        :param ex:
        :return:
        """
        raise NotImplementedError

    def add_callbacks(self, callbacks: Callbacks):
        """
        Add callbacks to client.

        :param callbacks:
        :return:
        """
        if not self.client.callbacks:
            self.client.callbacks = callbacks
        else:
            self.client.callbacks.extend(callbacks)

    @classmethod
    def support_streaming(cls):
        """Whether the backend can stream natively; default is no."""
        return False

    def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                  model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]:
        # Annotation fixed: was Union[str | List[BaseMessage]], which is a
        # TypeError on Python < 3.10 and redundant on 3.10+.
        if len(messages) == 0:
            raise ValueError("prompt must not be empty.")

        if not model_mode:
            model_mode = self.model_mode

        if model_mode == ModelMode.COMPLETION:
            # Completion mode uses only the first message's raw text.
            return messages[0].content
        else:
            chat_messages = []
            for message in messages:
                if message.type == MessageType.HUMAN:
                    chat_messages.append(HumanMessage(content=message.content))
                elif message.type == MessageType.ASSISTANT:
                    chat_messages.append(AIMessage(content=message.content))
                elif message.type == MessageType.SYSTEM:
                    chat_messages.append(SystemMessage(content=message.content))

            return chat_messages

    def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
        """
        convert model kwargs to provider model kwargs.

        Applies, per parameter: the provider alias, the rule default when the
        value is unset, and clamping to the rule's min/max bounds.

        :param model_rules:
        :param model_kwargs:
        :return:
        """
        model_kwargs_input = {}
        for key, value in model_kwargs.dict().items():
            rule = getattr(model_rules, key)
            if not rule.enabled:
                continue

            if rule.alias:
                key = rule.alias

            if rule.default is not None and value is None:
                value = rule.default

            if rule.min is not None:
                value = max(value, rule.min)

            if rule.max is not None:
                value = min(value, rule.max)

            model_kwargs_input[key] = value

        return model_kwargs_input

View File

@ -0,0 +1,70 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.llms import ChatGLM
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class ChatGLMModel(BaseLLM):
    """LLM wrapper for a self-hosted ChatGLM HTTP endpoint."""

    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        """Build the langchain ChatGLM client pointed at the configured endpoint."""
        sampling_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return ChatGLM(
            callbacks=self.callbacks,
            endpoint_url=self.credentials.get('api_base'),
            **sampling_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """Generate a completion for the given messages and stop words."""
        prompt = self._get_prompt_from_messages(messages)
        return self._client.generate([prompt], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """Count tokens for the rendered prompt, clamped at zero."""
        prompt = self._get_prompt_from_messages(messages)
        token_count = self._client.get_num_tokens(prompt)
        return token_count if token_count > 0 else 0

    def get_token_price(self, tokens: int, message_type: MessageType):
        # Self-hosted endpoint: no per-token pricing is tracked.
        return decimal.Decimal('0')

    def get_currency(self):
        return 'RMB'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        """Push normalized kwargs onto matching client attributes."""
        for key, value in self._to_model_kwargs_input(self.model_rules, model_kwargs).items():
            if hasattr(self.client, key):
                setattr(self.client, key, value)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Map client-side ValueErrors to LLMBadRequestError; pass others through."""
        if not isinstance(ex, ValueError):
            return ex
        return LLMBadRequestError(f"ChatGLM: {str(ex)}")

    @classmethod
    def support_streaming(cls):
        return False

View File

@ -0,0 +1,82 @@
import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain import HuggingFaceHub
from langchain.callbacks.manager import Callbacks
from langchain.llms import HuggingFaceEndpoint
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class HuggingfaceHubModel(BaseLLM):
    """LLM wrapper for Hugging Face hosted models (Hub repos or Inference Endpoints)."""

    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        """Build either an Inference-Endpoint client or a Hub client from credentials."""
        hf_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
            return HuggingFaceEndpoint(
                endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
                task='text2text-generation',
                model_kwargs=hf_model_kwargs,
                huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
                callbacks=self.callbacks,
            )
        return HuggingFaceHub(
            repo_id=self.name,
            task=self.credentials['task_type'],
            model_kwargs=hf_model_kwargs,
            huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
            callbacks=self.callbacks,
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """Generate a completion for the given messages and stop words."""
        prompt = self._get_prompt_from_messages(messages)
        return self._client.generate([prompt], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """Count tokens for the rendered prompt."""
        prompt = self._get_prompt_from_messages(messages)
        return self._client.get_num_tokens(prompt)

    def get_token_price(self, tokens: int, message_type: MessageType):
        # not support calc price
        return decimal.Decimal('0')

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        """Replace the client's model_kwargs wholesale with normalized values."""
        self.client.model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """All Hub errors surface as bad requests."""
        return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")

    @classmethod
    def support_streaming(cls):
        return False

View File

@ -0,0 +1,70 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.llms import Minimax
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class MinimaxModel(BaseLLM):
    """LLM wrapper for Minimax hosted models (streaming disabled at the client)."""

    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        """Build the langchain Minimax client with normalized sampling kwargs."""
        sampling_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return Minimax(
            model=self.name,
            model_kwargs={
                'stream': False
            },
            callbacks=self.callbacks,
            **self.credentials,
            **sampling_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """Generate a completion for the given messages and stop words."""
        prompt = self._get_prompt_from_messages(messages)
        return self._client.generate([prompt], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """Count tokens for the rendered prompt, clamped at zero."""
        prompt = self._get_prompt_from_messages(messages)
        token_count = self._client.get_num_tokens(prompt)
        return token_count if token_count > 0 else 0

    def get_token_price(self, tokens: int, message_type: MessageType):
        # No per-token pricing is tracked for Minimax.
        return decimal.Decimal('0')

    def get_currency(self):
        return 'RMB'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        """Push normalized kwargs onto matching client attributes."""
        for key, value in self._to_model_kwargs_input(self.model_rules, model_kwargs).items():
            if hasattr(self.client, key):
                setattr(self.client, key, value)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Map client-side ValueErrors to LLMBadRequestError; pass others through."""
        if not isinstance(ex, ValueError):
            return ex
        return LLMBadRequestError(f"Minimax: {str(ex)}")

View File

@ -0,0 +1,219 @@
import decimal
import logging
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from models.provider import ProviderType, ProviderQuotaType
# Models served through the legacy completions endpoint.
COMPLETION_MODELS = [
    'text-davinci-003',  # 4,097 tokens
]

# Models served through the chat completions endpoint.
CHAT_MODELS = [
    'gpt-4',  # 8,192 tokens
    'gpt-4-32k',  # 32,768 tokens
    'gpt-3.5-turbo',  # 4,096 tokens
    'gpt-3.5-turbo-16k',  # 16,384 tokens
]

# Maximum context window (prompt + completion tokens) per model.
MODEL_MAX_TOKENS = {
    'gpt-4': 8192,
    'gpt-4-32k': 32768,
    'gpt-3.5-turbo': 4096,
    'gpt-3.5-turbo-16k': 16384,
    'text-davinci-003': 4097,
}
class OpenAIModel(BaseLLM):
    """
    LLM wrapper for OpenAI completion and chat models.

    Model mode is fixed by the model family (COMPLETION_MODELS vs chat),
    and hosted trial-quota workspaces are blocked from GPT-4.
    """

    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        if name in COMPLETION_MODELS:
            self.model_mode = ModelMode.COMPLETION
        else:
            self.model_mode = ModelMode.CHAT

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
        """Build the langchain client for this model (completion or chat)."""
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        if self.name in COMPLETION_MODELS:
            client = EnhanceOpenAI(
                model_name=self.name,
                streaming=self.streaming,
                callbacks=self.callbacks,
                request_timeout=60,
                **self.credentials,
                **provider_model_kwargs
            )
        else:
            # Fine-tuning is currently only available for the following base models:
            # davinci, curie, babbage, and ada.
            # This means that except for the fixed `completion` model,
            # all other fine-tuned models are `completion` models.
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p'),
                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
            }

            client = EnhanceChatOpenAI(
                model_name=self.name,
                temperature=provider_model_kwargs.get('temperature'),
                max_tokens=provider_model_kwargs.get('max_tokens'),
                model_kwargs=extra_model_kwargs,
                streaming=self.streaming,
                callbacks=self.callbacks,
                request_timeout=60,
                **self.credentials
            )

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages: prompt messages to send
        :param stop: stop words
        :param callbacks: callbacks for this generation
        :raises ModelCurrentlyNotSupportError: GPT-4 on a hosted trial quota
        :return: raw langchain LLMResult
        """
        # GPT-4 is not available to hosted (system-provider) trial workspaces.
        if self.name == 'gpt-4' \
                and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
                and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
            raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")

        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages: prompt messages
        :return: token count (never negative)
        """
        prompts = self._get_prompt_from_messages(messages)
        if isinstance(prompts, str):
            return self._client.get_num_tokens(prompts)
        else:
            # Subtract one token per message to offset per-message framing,
            # clamped at zero.
            return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        """
        Return the USD price for `tokens` of the given message type.

        Unit prices are USD per 1K tokens keyed by model name.
        """
        model_unit_prices = {
            'gpt-4': {
                'prompt': decimal.Decimal('0.03'),
                'completion': decimal.Decimal('0.06'),
            },
            'gpt-4-32k': {
                'prompt': decimal.Decimal('0.06'),
                'completion': decimal.Decimal('0.12')
            },
            'gpt-3.5-turbo': {
                'prompt': decimal.Decimal('0.0015'),
                'completion': decimal.Decimal('0.002')
            },
            'gpt-3.5-turbo-16k': {
                'prompt': decimal.Decimal('0.003'),
                'completion': decimal.Decimal('0.004')
            },
            'text-davinci-003': {
                'prompt': decimal.Decimal('0.02'),
                'completion': decimal.Decimal('0.02')
            },
        }

        # Human/system messages are billed at prompt rate; assistant at completion rate.
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = model_unit_prices[self.name]['prompt']
        else:
            unit_price = model_unit_prices[self.name]['completion']

        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)
        total_price = tokens_per_1k * unit_price
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        """Prices returned by get_token_price are in US dollars."""
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        """Apply new kwargs to the already-constructed client in place."""
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        if self.name in COMPLETION_MODELS:
            # The completion client exposes each sampling param as an attribute.
            for k, v in provider_model_kwargs.items():
                if hasattr(self.client, k):
                    setattr(self.client, k, v)
        else:
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p'),
                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
            }
            self.client.temperature = provider_model_kwargs.get('temperature')
            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
            self.client.model_kwargs = extra_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Map openai SDK exceptions to the app's LLM error types.

        Always returns the mapped exception; callers raise it themselves.
        """
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            # Fix: this branch previously `raise`d instead of `return`ing,
            # unlike every other branch and the declared -> Exception contract.
            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        """OpenAI supports native token streaming."""
        return True
# def is_model_valid_or_raise(self):
# """
# check is a valid model.
#
# :return:
# """
# credentials = self._model_provider.get_credentials()
#
# try:
# result = openai.Model.retrieve(
# id=self.name,
# api_key=credentials.get('openai_api_key'),
# request_timeout=60
# )
#
# if 'id' not in result or result['id'] != self.name:
# raise LLMNotExistsError(f"OpenAI Model {self.name} not exists.")
# except openai.error.OpenAIError as e:
# raise LLMNotExistsError(f"OpenAI Model {self.name} not exists, cause: {e.__class__.__name__}:{str(e)}")
# except Exception as e:
# logging.exception("OpenAI Model retrieve failed.")
# raise e

View File

@ -0,0 +1,103 @@
import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, get_buffer_string
from replicate.exceptions import ReplicateError, ModelError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.error import LLMBadRequestError
from core.third_party.langchain.llms.replicate_llm import EnhanceReplicate
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class ReplicateModel(BaseLLM):
    """LLM wrapper for models hosted on Replicate."""

    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        # Replicate chat variants are named '*-chat'; everything else is completion.
        self.model_mode = ModelMode.CHAT if name.endswith('-chat') else ModelMode.COMPLETION

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
        """Build the Replicate client for '<name>:<version>' with normalized inputs."""
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return EnhanceReplicate(
            model=self.name + ':' + self.credentials.get('model_version'),
            input=provider_model_kwargs,
            streaming=self.streaming,
            replicate_api_token=self.credentials.get('replicate_api_token'),
            callbacks=self.callbacks,
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        In chat mode the first system message is lifted out into Replicate's
        `system_prompt` input and the remaining messages are flattened to a
        single string prompt.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)

        extra_kwargs = {}
        if isinstance(prompts, list):
            system_messages = [message for message in messages if message.type == 'system']
            if system_messages:
                system_message = system_messages[0]
                extra_kwargs['system_prompt'] = system_message.content
                # NOTE(review): this rebuilds `prompts` from the raw
                # PromptMessage objects rather than from the langchain
                # messages produced above; get_buffer_string below expects
                # langchain BaseMessage instances — confirm PromptMessage is
                # accepted (and that `message.type == 'system'` matches the
                # MessageType enum) or this path misbehaves.
                prompts = [message for message in messages if message.type != 'system']

            prompts = get_buffer_string(prompts)

        # The maximum length the generated tokens can have.
        # Corresponds to the length of the input prompt + max_new_tokens.
        if 'max_length' in self._client.input:
            self._client.input['max_length'] = min(
                self._client.input['max_length'] + self.get_num_tokens(messages),
                self.model_rules.max_tokens.max
            )

        return self._client.generate([prompts], stop, callbacks, **extra_kwargs)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        if isinstance(prompts, list):
            prompts = get_buffer_string(prompts)

        return self._client.get_num_tokens(prompts)

    def get_token_price(self, tokens: int, message_type: MessageType):
        # replicate only pay for prediction seconds
        return decimal.Decimal('0')

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        """Replace the client's input dict wholesale with normalized kwargs."""
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        self.client.input = provider_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Map Replicate SDK errors to LLMBadRequestError; pass others through."""
        if isinstance(ex, (ModelError, ReplicateError)):
            return LLMBadRequestError(f"Replicate: {str(ex)}")
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True

View File

@ -0,0 +1,73 @@
import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.spark import ChatSpark
from core.third_party.spark.spark_llm import SparkError
class SparkModel(BaseLLM):
    """LLM wrapper for iFlytek Spark chat models."""

    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        """Build the ChatSpark client from credentials plus normalized sampling kwargs."""
        sampling_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return ChatSpark(
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
            **sampling_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """Generate a chat completion for the given messages and stop words."""
        chat_prompt = self._get_prompt_from_messages(messages)
        return self._client.generate([chat_prompt], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """Count tokens over the concatenated message contents, clamped at zero."""
        joined_content = "".join(message.content for message in messages)
        token_count = self._client.get_num_tokens(joined_content)
        return token_count if token_count > 0 else 0

    def get_token_price(self, tokens: int, message_type: MessageType):
        # No per-token pricing is tracked for Spark.
        return decimal.Decimal('0')

    def get_currency(self):
        return 'RMB'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        """Push normalized kwargs onto matching client attributes."""
        for key, value in self._to_model_kwargs_input(self.model_rules, model_kwargs).items():
            if hasattr(self.client, key):
                setattr(self.client, key, value)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Map SparkError to LLMBadRequestError; pass others through."""
        if not isinstance(ex, SparkError):
            return ex
        return LLMBadRequestError(f"Spark: {str(ex)}")

    @classmethod
    def support_streaming(cls):
        return True

Some files were not shown because too many files have changed in this diff Show More