mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
Feat/dataset service api (#1245)
Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
parent
54ff03c35d
commit
46154c6705
|
@ -81,6 +81,7 @@ class BaseApiKeyListResource(Resource):
|
||||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||||
api_token = ApiToken()
|
api_token = ApiToken()
|
||||||
setattr(api_token, self.resource_id_field, resource_id)
|
setattr(api_token, self.resource_id_field, resource_id)
|
||||||
|
api_token.tenant_id = current_user.current_tenant_id
|
||||||
api_token.token = key
|
api_token.token = key
|
||||||
api_token.type = self.resource_type
|
api_token.type = self.resource_type
|
||||||
db.session.add(api_token)
|
db.session.add(api_token)
|
||||||
|
|
|
@ -19,41 +19,13 @@ from core.model_providers.model_factory import ModelFactory
|
||||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
from core.model_providers.models.entity.model_params import ModelType
|
from core.model_providers.models.entity.model_params import ModelType
|
||||||
from events.app_event import app_was_created, app_was_deleted
|
from events.app_event import app_was_created, app_was_deleted
|
||||||
|
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
|
||||||
|
app_detail_fields_with_site
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App, AppModelConfig, Site
|
from models.model import App, AppModelConfig, Site
|
||||||
from services.app_model_config_service import AppModelConfigService
|
from services.app_model_config_service import AppModelConfigService
|
||||||
|
|
||||||
model_config_fields = {
|
|
||||||
'opening_statement': fields.String,
|
|
||||||
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
|
|
||||||
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
|
|
||||||
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
|
|
||||||
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
|
|
||||||
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
|
|
||||||
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
|
|
||||||
'model': fields.Raw(attribute='model_dict'),
|
|
||||||
'user_input_form': fields.Raw(attribute='user_input_form_list'),
|
|
||||||
'dataset_query_variable': fields.String,
|
|
||||||
'pre_prompt': fields.String,
|
|
||||||
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
|
|
||||||
}
|
|
||||||
|
|
||||||
app_detail_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'mode': fields.String,
|
|
||||||
'icon': fields.String,
|
|
||||||
'icon_background': fields.String,
|
|
||||||
'enable_site': fields.Boolean,
|
|
||||||
'enable_api': fields.Boolean,
|
|
||||||
'api_rpm': fields.Integer,
|
|
||||||
'api_rph': fields.Integer,
|
|
||||||
'is_demo': fields.Boolean,
|
|
||||||
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
|
|
||||||
'created_at': TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_app(app_id, tenant_id):
|
def _get_app(app_id, tenant_id):
|
||||||
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
|
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
|
||||||
|
@ -63,35 +35,6 @@ def _get_app(app_id, tenant_id):
|
||||||
|
|
||||||
|
|
||||||
class AppListApi(Resource):
|
class AppListApi(Resource):
|
||||||
prompt_config_fields = {
|
|
||||||
'prompt_template': fields.String,
|
|
||||||
}
|
|
||||||
|
|
||||||
model_config_partial_fields = {
|
|
||||||
'model': fields.Raw(attribute='model_dict'),
|
|
||||||
'pre_prompt': fields.String,
|
|
||||||
}
|
|
||||||
|
|
||||||
app_partial_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'mode': fields.String,
|
|
||||||
'icon': fields.String,
|
|
||||||
'icon_background': fields.String,
|
|
||||||
'enable_site': fields.Boolean,
|
|
||||||
'enable_api': fields.Boolean,
|
|
||||||
'is_demo': fields.Boolean,
|
|
||||||
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
|
|
||||||
'created_at': TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
app_pagination_fields = {
|
|
||||||
'page': fields.Integer,
|
|
||||||
'limit': fields.Integer(attribute='per_page'),
|
|
||||||
'total': fields.Integer,
|
|
||||||
'has_more': fields.Boolean(attribute='has_next'),
|
|
||||||
'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -238,18 +181,6 @@ class AppListApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class AppTemplateApi(Resource):
|
class AppTemplateApi(Resource):
|
||||||
template_fields = {
|
|
||||||
'name': fields.String,
|
|
||||||
'icon': fields.String,
|
|
||||||
'icon_background': fields.String,
|
|
||||||
'description': fields.String,
|
|
||||||
'mode': fields.String,
|
|
||||||
'model_config': fields.Nested(model_config_fields),
|
|
||||||
}
|
|
||||||
|
|
||||||
template_list_fields = {
|
|
||||||
'data': fields.List(fields.Nested(template_fields)),
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -268,38 +199,6 @@ class AppTemplateApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class AppApi(Resource):
|
class AppApi(Resource):
|
||||||
site_fields = {
|
|
||||||
'access_token': fields.String(attribute='code'),
|
|
||||||
'code': fields.String,
|
|
||||||
'title': fields.String,
|
|
||||||
'icon': fields.String,
|
|
||||||
'icon_background': fields.String,
|
|
||||||
'description': fields.String,
|
|
||||||
'default_language': fields.String,
|
|
||||||
'customize_domain': fields.String,
|
|
||||||
'copyright': fields.String,
|
|
||||||
'privacy_policy': fields.String,
|
|
||||||
'customize_token_strategy': fields.String,
|
|
||||||
'prompt_public': fields.Boolean,
|
|
||||||
'app_base_url': fields.String,
|
|
||||||
}
|
|
||||||
|
|
||||||
app_detail_fields_with_site = {
|
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'mode': fields.String,
|
|
||||||
'icon': fields.String,
|
|
||||||
'icon_background': fields.String,
|
|
||||||
'enable_site': fields.Boolean,
|
|
||||||
'enable_api': fields.Boolean,
|
|
||||||
'api_rpm': fields.Integer,
|
|
||||||
'api_rph': fields.Integer,
|
|
||||||
'is_demo': fields.Boolean,
|
|
||||||
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
|
|
||||||
'site': fields.Nested(site_fields),
|
|
||||||
'api_base_url': fields.String,
|
|
||||||
'created_at': TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
|
|
@ -13,107 +13,14 @@ from controllers.console import api
|
||||||
from controllers.console.app import _get_app
|
from controllers.console.app import _get_app
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
|
from fields.conversation_fields import conversation_pagination_fields, conversation_detail_fields, \
|
||||||
|
conversation_message_detail_fields, conversation_with_summary_pagination_fields
|
||||||
from libs.helper import TimestampField, datetime_string, uuid_value
|
from libs.helper import TimestampField, datetime_string, uuid_value
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import Message, MessageAnnotation, Conversation
|
from models.model import Message, MessageAnnotation, Conversation
|
||||||
|
|
||||||
account_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'email': fields.String
|
|
||||||
}
|
|
||||||
|
|
||||||
feedback_fields = {
|
|
||||||
'rating': fields.String,
|
|
||||||
'content': fields.String,
|
|
||||||
'from_source': fields.String,
|
|
||||||
'from_end_user_id': fields.String,
|
|
||||||
'from_account': fields.Nested(account_fields, allow_null=True),
|
|
||||||
}
|
|
||||||
|
|
||||||
annotation_fields = {
|
|
||||||
'content': fields.String,
|
|
||||||
'account': fields.Nested(account_fields, allow_null=True),
|
|
||||||
'created_at': TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
message_detail_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'conversation_id': fields.String,
|
|
||||||
'inputs': fields.Raw,
|
|
||||||
'query': fields.String,
|
|
||||||
'message': fields.Raw,
|
|
||||||
'message_tokens': fields.Integer,
|
|
||||||
'answer': fields.String,
|
|
||||||
'answer_tokens': fields.Integer,
|
|
||||||
'provider_response_latency': fields.Float,
|
|
||||||
'from_source': fields.String,
|
|
||||||
'from_end_user_id': fields.String,
|
|
||||||
'from_account_id': fields.String,
|
|
||||||
'feedbacks': fields.List(fields.Nested(feedback_fields)),
|
|
||||||
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
|
||||||
'created_at': TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
feedback_stat_fields = {
|
|
||||||
'like': fields.Integer,
|
|
||||||
'dislike': fields.Integer
|
|
||||||
}
|
|
||||||
|
|
||||||
model_config_fields = {
|
|
||||||
'opening_statement': fields.String,
|
|
||||||
'suggested_questions': fields.Raw,
|
|
||||||
'model': fields.Raw,
|
|
||||||
'user_input_form': fields.Raw,
|
|
||||||
'pre_prompt': fields.String,
|
|
||||||
'agent_mode': fields.Raw,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionConversationApi(Resource):
|
class CompletionConversationApi(Resource):
|
||||||
class MessageTextField(fields.Raw):
|
|
||||||
def format(self, value):
|
|
||||||
return value[0]['text'] if value else ''
|
|
||||||
|
|
||||||
simple_configs_fields = {
|
|
||||||
'prompt_template': fields.String,
|
|
||||||
}
|
|
||||||
|
|
||||||
simple_model_config_fields = {
|
|
||||||
'model': fields.Raw(attribute='model_dict'),
|
|
||||||
'pre_prompt': fields.String,
|
|
||||||
}
|
|
||||||
|
|
||||||
simple_message_detail_fields = {
|
|
||||||
'inputs': fields.Raw,
|
|
||||||
'query': fields.String,
|
|
||||||
'message': MessageTextField,
|
|
||||||
'answer': fields.String,
|
|
||||||
}
|
|
||||||
|
|
||||||
conversation_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'status': fields.String,
|
|
||||||
'from_source': fields.String,
|
|
||||||
'from_end_user_id': fields.String,
|
|
||||||
'from_end_user_session_id': fields.String(),
|
|
||||||
'from_account_id': fields.String,
|
|
||||||
'read_at': TimestampField,
|
|
||||||
'created_at': TimestampField,
|
|
||||||
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
|
||||||
'model_config': fields.Nested(simple_model_config_fields),
|
|
||||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
|
||||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields),
|
|
||||||
'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
|
|
||||||
}
|
|
||||||
|
|
||||||
conversation_pagination_fields = {
|
|
||||||
'page': fields.Integer,
|
|
||||||
'limit': fields.Integer(attribute='per_page'),
|
|
||||||
'total': fields.Integer,
|
|
||||||
'has_more': fields.Boolean(attribute='has_next'),
|
|
||||||
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -191,21 +98,11 @@ class CompletionConversationApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class CompletionConversationDetailApi(Resource):
|
class CompletionConversationDetailApi(Resource):
|
||||||
conversation_detail_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'status': fields.String,
|
|
||||||
'from_source': fields.String,
|
|
||||||
'from_end_user_id': fields.String,
|
|
||||||
'from_account_id': fields.String,
|
|
||||||
'created_at': TimestampField,
|
|
||||||
'model_config': fields.Nested(model_config_fields),
|
|
||||||
'message': fields.Nested(message_detail_fields, attribute='first_message'),
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(conversation_detail_fields)
|
@marshal_with(conversation_message_detail_fields)
|
||||||
def get(self, app_id, conversation_id):
|
def get(self, app_id, conversation_id):
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
@ -234,44 +131,11 @@ class CompletionConversationDetailApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class ChatConversationApi(Resource):
|
class ChatConversationApi(Resource):
|
||||||
simple_configs_fields = {
|
|
||||||
'prompt_template': fields.String,
|
|
||||||
}
|
|
||||||
|
|
||||||
simple_model_config_fields = {
|
|
||||||
'model': fields.Raw(attribute='model_dict'),
|
|
||||||
'pre_prompt': fields.String,
|
|
||||||
}
|
|
||||||
|
|
||||||
conversation_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'status': fields.String,
|
|
||||||
'from_source': fields.String,
|
|
||||||
'from_end_user_id': fields.String,
|
|
||||||
'from_end_user_session_id': fields.String,
|
|
||||||
'from_account_id': fields.String,
|
|
||||||
'summary': fields.String(attribute='summary_or_query'),
|
|
||||||
'read_at': TimestampField,
|
|
||||||
'created_at': TimestampField,
|
|
||||||
'annotated': fields.Boolean,
|
|
||||||
'model_config': fields.Nested(simple_model_config_fields),
|
|
||||||
'message_count': fields.Integer,
|
|
||||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
|
||||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
|
|
||||||
}
|
|
||||||
|
|
||||||
conversation_pagination_fields = {
|
|
||||||
'page': fields.Integer,
|
|
||||||
'limit': fields.Integer(attribute='per_page'),
|
|
||||||
'total': fields.Integer,
|
|
||||||
'has_more': fields.Boolean(attribute='has_next'),
|
|
||||||
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(conversation_pagination_fields)
|
@marshal_with(conversation_with_summary_pagination_fields)
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
|
|
||||||
|
@ -356,19 +220,6 @@ class ChatConversationApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class ChatConversationDetailApi(Resource):
|
class ChatConversationDetailApi(Resource):
|
||||||
conversation_detail_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'status': fields.String,
|
|
||||||
'from_source': fields.String,
|
|
||||||
'from_end_user_id': fields.String,
|
|
||||||
'from_account_id': fields.String,
|
|
||||||
'created_at': TimestampField,
|
|
||||||
'annotated': fields.Boolean,
|
|
||||||
'model_config': fields.Nested(model_config_fields),
|
|
||||||
'message_count': fields.Integer,
|
|
||||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
|
||||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
|
|
@ -17,6 +17,7 @@ from controllers.console.wraps import account_initialization_required
|
||||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||||
from core.login.login import login_required
|
from core.login.login import login_required
|
||||||
|
from fields.conversation_fields import message_detail_fields
|
||||||
from libs.helper import uuid_value, TimestampField
|
from libs.helper import uuid_value, TimestampField
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
@ -27,44 +28,6 @@ from services.errors.conversation import ConversationNotExistsError
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
from services.message_service import MessageService
|
from services.message_service import MessageService
|
||||||
|
|
||||||
account_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'email': fields.String
|
|
||||||
}
|
|
||||||
|
|
||||||
feedback_fields = {
|
|
||||||
'rating': fields.String,
|
|
||||||
'content': fields.String,
|
|
||||||
'from_source': fields.String,
|
|
||||||
'from_end_user_id': fields.String,
|
|
||||||
'from_account': fields.Nested(account_fields, allow_null=True),
|
|
||||||
}
|
|
||||||
|
|
||||||
annotation_fields = {
|
|
||||||
'content': fields.String,
|
|
||||||
'account': fields.Nested(account_fields, allow_null=True),
|
|
||||||
'created_at': TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
message_detail_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'conversation_id': fields.String,
|
|
||||||
'inputs': fields.Raw,
|
|
||||||
'query': fields.String,
|
|
||||||
'message': fields.Raw,
|
|
||||||
'message_tokens': fields.Integer,
|
|
||||||
'answer': fields.String,
|
|
||||||
'answer_tokens': fields.Integer,
|
|
||||||
'provider_response_latency': fields.Float,
|
|
||||||
'from_source': fields.String,
|
|
||||||
'from_end_user_id': fields.String,
|
|
||||||
'from_account_id': fields.String,
|
|
||||||
'feedbacks': fields.List(fields.Nested(feedback_fields)),
|
|
||||||
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
|
||||||
'created_at': TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageListApi(Resource):
|
class ChatMessageListApi(Resource):
|
||||||
message_infinite_scroll_pagination_fields = {
|
message_infinite_scroll_pagination_fields = {
|
||||||
|
|
|
@ -8,26 +8,11 @@ from controllers.console import api
|
||||||
from controllers.console.app import _get_app
|
from controllers.console.app import _get_app
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
|
from fields.app_fields import app_site_fields
|
||||||
from libs.helper import supported_language
|
from libs.helper import supported_language
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import Site
|
from models.model import Site
|
||||||
|
|
||||||
app_site_fields = {
|
|
||||||
'app_id': fields.String,
|
|
||||||
'access_token': fields.String(attribute='code'),
|
|
||||||
'code': fields.String,
|
|
||||||
'title': fields.String,
|
|
||||||
'icon': fields.String,
|
|
||||||
'icon_background': fields.String,
|
|
||||||
'description': fields.String,
|
|
||||||
'default_language': fields.String,
|
|
||||||
'customize_domain': fields.String,
|
|
||||||
'copyright': fields.String,
|
|
||||||
'privacy_policy': fields.String,
|
|
||||||
'customize_token_strategy': fields.String,
|
|
||||||
'prompt_public': fields.Boolean
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def parse_app_site_args():
|
def parse_app_site_args():
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
|
@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required
|
||||||
from core.data_loader.loader.notion import NotionLoader
|
from core.data_loader.loader.notion import NotionLoader
|
||||||
from core.indexing_runner import IndexingRunner
|
from core.indexing_runner import IndexingRunner
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from fields.data_source_fields import integrate_notion_info_list_fields, integrate_list_fields
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from models.dataset import Document
|
from models.dataset import Document
|
||||||
from models.source import DataSourceBinding
|
from models.source import DataSourceBinding
|
||||||
|
@ -24,37 +25,6 @@ cache = TTLCache(maxsize=None, ttl=30)
|
||||||
|
|
||||||
|
|
||||||
class DataSourceApi(Resource):
|
class DataSourceApi(Resource):
|
||||||
integrate_icon_fields = {
|
|
||||||
'type': fields.String,
|
|
||||||
'url': fields.String,
|
|
||||||
'emoji': fields.String
|
|
||||||
}
|
|
||||||
integrate_page_fields = {
|
|
||||||
'page_name': fields.String,
|
|
||||||
'page_id': fields.String,
|
|
||||||
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
|
|
||||||
'parent_id': fields.String,
|
|
||||||
'type': fields.String
|
|
||||||
}
|
|
||||||
integrate_workspace_fields = {
|
|
||||||
'workspace_name': fields.String,
|
|
||||||
'workspace_id': fields.String,
|
|
||||||
'workspace_icon': fields.String,
|
|
||||||
'pages': fields.List(fields.Nested(integrate_page_fields)),
|
|
||||||
'total': fields.Integer
|
|
||||||
}
|
|
||||||
integrate_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'provider': fields.String,
|
|
||||||
'created_at': TimestampField,
|
|
||||||
'is_bound': fields.Boolean,
|
|
||||||
'disabled': fields.Boolean,
|
|
||||||
'link': fields.String,
|
|
||||||
'source_info': fields.Nested(integrate_workspace_fields)
|
|
||||||
}
|
|
||||||
integrate_list_fields = {
|
|
||||||
'data': fields.List(fields.Nested(integrate_fields)),
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -131,28 +101,6 @@ class DataSourceApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class DataSourceNotionListApi(Resource):
|
class DataSourceNotionListApi(Resource):
|
||||||
integrate_icon_fields = {
|
|
||||||
'type': fields.String,
|
|
||||||
'url': fields.String,
|
|
||||||
'emoji': fields.String
|
|
||||||
}
|
|
||||||
integrate_page_fields = {
|
|
||||||
'page_name': fields.String,
|
|
||||||
'page_id': fields.String,
|
|
||||||
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
|
|
||||||
'is_bound': fields.Boolean,
|
|
||||||
'parent_id': fields.String,
|
|
||||||
'type': fields.String
|
|
||||||
}
|
|
||||||
integrate_workspace_fields = {
|
|
||||||
'workspace_name': fields.String,
|
|
||||||
'workspace_id': fields.String,
|
|
||||||
'workspace_icon': fields.String,
|
|
||||||
'pages': fields.List(fields.Nested(integrate_page_fields))
|
|
||||||
}
|
|
||||||
integrate_notion_info_list_fields = {
|
|
||||||
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
from flask import request
|
import flask_restful
|
||||||
|
from flask import request, current_app
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
|
|
||||||
|
from controllers.console.apikey import api_key_list, api_key_fields
|
||||||
from core.login.login import login_required
|
from core.login.login import login_required
|
||||||
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
|
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
|
||||||
from werkzeug.exceptions import NotFound, Forbidden
|
from werkzeug.exceptions import NotFound, Forbidden
|
||||||
|
@ -12,45 +15,16 @@ from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from core.indexing_runner import IndexingRunner
|
from core.indexing_runner import IndexingRunner
|
||||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||||
from core.model_providers.model_factory import ModelFactory
|
|
||||||
from core.model_providers.models.entity.model_params import ModelType
|
from core.model_providers.models.entity.model_params import ModelType
|
||||||
from libs.helper import TimestampField
|
from fields.app_fields import related_app_list
|
||||||
|
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||||
|
from fields.document_fields import document_status_fields
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import DocumentSegment, Document
|
from models.dataset import DocumentSegment, Document
|
||||||
from models.model import UploadFile
|
from models.model import UploadFile, ApiToken
|
||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
from services.provider_service import ProviderService
|
from services.provider_service import ProviderService
|
||||||
|
|
||||||
dataset_detail_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'description': fields.String,
|
|
||||||
'provider': fields.String,
|
|
||||||
'permission': fields.String,
|
|
||||||
'data_source_type': fields.String,
|
|
||||||
'indexing_technique': fields.String,
|
|
||||||
'app_count': fields.Integer,
|
|
||||||
'document_count': fields.Integer,
|
|
||||||
'word_count': fields.Integer,
|
|
||||||
'created_by': fields.String,
|
|
||||||
'created_at': TimestampField,
|
|
||||||
'updated_by': fields.String,
|
|
||||||
'updated_at': TimestampField,
|
|
||||||
'embedding_model': fields.String,
|
|
||||||
'embedding_model_provider': fields.String,
|
|
||||||
'embedding_available': fields.Boolean
|
|
||||||
}
|
|
||||||
|
|
||||||
dataset_query_detail_fields = {
|
|
||||||
"id": fields.String,
|
|
||||||
"content": fields.String,
|
|
||||||
"source": fields.String,
|
|
||||||
"source_app_id": fields.String,
|
|
||||||
"created_by_role": fields.String,
|
|
||||||
"created_by": fields.String,
|
|
||||||
"created_at": TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_name(name):
|
def _validate_name(name):
|
||||||
if not name or len(name) < 1 or len(name) > 40:
|
if not name or len(name) < 1 or len(name) > 40:
|
||||||
|
@ -82,7 +56,8 @@ class DatasetListApi(Resource):
|
||||||
|
|
||||||
# check embedding setting
|
# check embedding setting
|
||||||
provider_service = ProviderService()
|
provider_service = ProviderService()
|
||||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
|
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
|
||||||
|
ModelType.EMBEDDINGS.value)
|
||||||
# if len(valid_model_list) == 0:
|
# if len(valid_model_list) == 0:
|
||||||
# raise ProviderNotInitializeError(
|
# raise ProviderNotInitializeError(
|
||||||
# f"No Embedding Model available. Please configure a valid provider "
|
# f"No Embedding Model available. Please configure a valid provider "
|
||||||
|
@ -157,7 +132,8 @@ class DatasetApi(Resource):
|
||||||
# check embedding setting
|
# check embedding setting
|
||||||
provider_service = ProviderService()
|
provider_service = ProviderService()
|
||||||
# get valid model list
|
# get valid model list
|
||||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
|
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
|
||||||
|
ModelType.EMBEDDINGS.value)
|
||||||
model_names = []
|
model_names = []
|
||||||
for valid_model in valid_model_list:
|
for valid_model in valid_model_list:
|
||||||
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
|
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
|
||||||
|
@ -271,7 +247,8 @@ class DatasetIndexingEstimateApi(Resource):
|
||||||
parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
|
parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
|
||||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||||
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
|
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
|
||||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||||
|
location='json')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# validate args
|
# validate args
|
||||||
DocumentService.estimate_args_validate(args)
|
DocumentService.estimate_args_validate(args)
|
||||||
|
@ -320,18 +297,6 @@ class DatasetIndexingEstimateApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class DatasetRelatedAppListApi(Resource):
|
class DatasetRelatedAppListApi(Resource):
|
||||||
app_detail_kernel_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'mode': fields.String,
|
|
||||||
'icon': fields.String,
|
|
||||||
'icon_background': fields.String,
|
|
||||||
}
|
|
||||||
|
|
||||||
related_app_list = {
|
|
||||||
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
|
|
||||||
'total': fields.Integer,
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -363,24 +328,6 @@ class DatasetRelatedAppListApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class DatasetIndexingStatusApi(Resource):
|
class DatasetIndexingStatusApi(Resource):
|
||||||
document_status_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'indexing_status': fields.String,
|
|
||||||
'processing_started_at': TimestampField,
|
|
||||||
'parsing_completed_at': TimestampField,
|
|
||||||
'cleaning_completed_at': TimestampField,
|
|
||||||
'splitting_completed_at': TimestampField,
|
|
||||||
'completed_at': TimestampField,
|
|
||||||
'paused_at': TimestampField,
|
|
||||||
'error': fields.String,
|
|
||||||
'stopped_at': TimestampField,
|
|
||||||
'completed_segments': fields.Integer,
|
|
||||||
'total_segments': fields.Integer,
|
|
||||||
}
|
|
||||||
|
|
||||||
document_status_fields_list = {
|
|
||||||
'data': fields.List(fields.Nested(document_status_fields))
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -400,16 +347,97 @@ class DatasetIndexingStatusApi(Resource):
|
||||||
DocumentSegment.status != 're_segment').count()
|
DocumentSegment.status != 're_segment').count()
|
||||||
document.completed_segments = completed_segments
|
document.completed_segments = completed_segments
|
||||||
document.total_segments = total_segments
|
document.total_segments = total_segments
|
||||||
documents_status.append(marshal(document, self.document_status_fields))
|
documents_status.append(marshal(document, document_status_fields))
|
||||||
data = {
|
data = {
|
||||||
'data': documents_status
|
'data': documents_status
|
||||||
}
|
}
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetApiKeyApi(Resource):
|
||||||
|
max_keys = 10
|
||||||
|
token_prefix = 'dataset-'
|
||||||
|
resource_type = 'dataset'
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@marshal_with(api_key_list)
|
||||||
|
def get(self):
|
||||||
|
keys = db.session.query(ApiToken). \
|
||||||
|
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
|
||||||
|
all()
|
||||||
|
return {"items": keys}
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@marshal_with(api_key_fields)
|
||||||
|
def post(self):
|
||||||
|
# The role of the current user in the ta table must be admin or owner
|
||||||
|
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
current_key_count = db.session.query(ApiToken). \
|
||||||
|
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
|
||||||
|
count()
|
||||||
|
|
||||||
|
if current_key_count >= self.max_keys:
|
||||||
|
flask_restful.abort(
|
||||||
|
400,
|
||||||
|
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||||
|
code='max_keys_exceeded'
|
||||||
|
)
|
||||||
|
|
||||||
|
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||||
|
api_token = ApiToken()
|
||||||
|
api_token.tenant_id = current_user.current_tenant_id
|
||||||
|
api_token.token = key
|
||||||
|
api_token.type = self.resource_type
|
||||||
|
db.session.add(api_token)
|
||||||
|
db.session.commit()
|
||||||
|
return api_token, 200
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def delete(self, api_key_id):
|
||||||
|
api_key_id = str(api_key_id)
|
||||||
|
|
||||||
|
# The role of the current user in the ta table must be admin or owner
|
||||||
|
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
key = db.session.query(ApiToken). \
|
||||||
|
filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
|
||||||
|
ApiToken.id == api_key_id). \
|
||||||
|
first()
|
||||||
|
|
||||||
|
if key is None:
|
||||||
|
flask_restful.abort(404, message='API key not found')
|
||||||
|
|
||||||
|
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return {'result': 'success'}, 204
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetApiBaseUrlApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self):
|
||||||
|
return {
|
||||||
|
'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
|
||||||
|
else request.host_url.rstrip('/')) + '/v1'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(DatasetListApi, '/datasets')
|
api.add_resource(DatasetListApi, '/datasets')
|
||||||
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
|
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
|
||||||
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
|
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
|
||||||
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
|
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
|
||||||
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
|
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
|
||||||
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
|
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
|
||||||
|
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
|
||||||
|
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
|
||||||
|
|
|
@ -23,6 +23,8 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
|
||||||
LLMBadRequestError
|
LLMBadRequestError
|
||||||
from core.model_providers.model_factory import ModelFactory
|
from core.model_providers.model_factory import ModelFactory
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
|
from fields.document_fields import document_with_segments_fields, document_fields, \
|
||||||
|
dataset_and_document_fields, document_status_fields
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import DatasetProcessRule, Dataset
|
from models.dataset import DatasetProcessRule, Dataset
|
||||||
|
@ -32,64 +34,6 @@ from services.dataset_service import DocumentService, DatasetService
|
||||||
from tasks.add_document_to_index_task import add_document_to_index_task
|
from tasks.add_document_to_index_task import add_document_to_index_task
|
||||||
from tasks.remove_document_from_index_task import remove_document_from_index_task
|
from tasks.remove_document_from_index_task import remove_document_from_index_task
|
||||||
|
|
||||||
dataset_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'description': fields.String,
|
|
||||||
'permission': fields.String,
|
|
||||||
'data_source_type': fields.String,
|
|
||||||
'indexing_technique': fields.String,
|
|
||||||
'created_by': fields.String,
|
|
||||||
'created_at': TimestampField,
|
|
||||||
}
|
|
||||||
|
|
||||||
document_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'position': fields.Integer,
|
|
||||||
'data_source_type': fields.String,
|
|
||||||
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
|
|
||||||
'dataset_process_rule_id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'created_from': fields.String,
|
|
||||||
'created_by': fields.String,
|
|
||||||
'created_at': TimestampField,
|
|
||||||
'tokens': fields.Integer,
|
|
||||||
'indexing_status': fields.String,
|
|
||||||
'error': fields.String,
|
|
||||||
'enabled': fields.Boolean,
|
|
||||||
'disabled_at': TimestampField,
|
|
||||||
'disabled_by': fields.String,
|
|
||||||
'archived': fields.Boolean,
|
|
||||||
'display_status': fields.String,
|
|
||||||
'word_count': fields.Integer,
|
|
||||||
'hit_count': fields.Integer,
|
|
||||||
'doc_form': fields.String,
|
|
||||||
}
|
|
||||||
|
|
||||||
document_with_segments_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'position': fields.Integer,
|
|
||||||
'data_source_type': fields.String,
|
|
||||||
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
|
|
||||||
'dataset_process_rule_id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'created_from': fields.String,
|
|
||||||
'created_by': fields.String,
|
|
||||||
'created_at': TimestampField,
|
|
||||||
'tokens': fields.Integer,
|
|
||||||
'indexing_status': fields.String,
|
|
||||||
'error': fields.String,
|
|
||||||
'enabled': fields.Boolean,
|
|
||||||
'disabled_at': TimestampField,
|
|
||||||
'disabled_by': fields.String,
|
|
||||||
'archived': fields.Boolean,
|
|
||||||
'display_status': fields.String,
|
|
||||||
'word_count': fields.Integer,
|
|
||||||
'hit_count': fields.Integer,
|
|
||||||
'completed_segments': fields.Integer,
|
|
||||||
'total_segments': fields.Integer
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DocumentResource(Resource):
|
class DocumentResource(Resource):
|
||||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||||
|
@ -303,11 +247,6 @@ class DatasetDocumentListApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
class DatasetInitApi(Resource):
|
class DatasetInitApi(Resource):
|
||||||
dataset_and_document_fields = {
|
|
||||||
'dataset': fields.Nested(dataset_fields),
|
|
||||||
'documents': fields.List(fields.Nested(document_fields)),
|
|
||||||
'batch': fields.String
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -504,24 +443,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||||
|
|
||||||
|
|
||||||
class DocumentBatchIndexingStatusApi(DocumentResource):
|
class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||||
document_status_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'indexing_status': fields.String,
|
|
||||||
'processing_started_at': TimestampField,
|
|
||||||
'parsing_completed_at': TimestampField,
|
|
||||||
'cleaning_completed_at': TimestampField,
|
|
||||||
'splitting_completed_at': TimestampField,
|
|
||||||
'completed_at': TimestampField,
|
|
||||||
'paused_at': TimestampField,
|
|
||||||
'error': fields.String,
|
|
||||||
'stopped_at': TimestampField,
|
|
||||||
'completed_segments': fields.Integer,
|
|
||||||
'total_segments': fields.Integer,
|
|
||||||
}
|
|
||||||
|
|
||||||
document_status_fields_list = {
|
|
||||||
'data': fields.List(fields.Nested(document_status_fields))
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -541,7 +462,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||||
document.total_segments = total_segments
|
document.total_segments = total_segments
|
||||||
if document.is_paused:
|
if document.is_paused:
|
||||||
document.indexing_status = 'paused'
|
document.indexing_status = 'paused'
|
||||||
documents_status.append(marshal(document, self.document_status_fields))
|
documents_status.append(marshal(document, document_status_fields))
|
||||||
data = {
|
data = {
|
||||||
'data': documents_status
|
'data': documents_status
|
||||||
}
|
}
|
||||||
|
@ -549,20 +470,6 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||||
|
|
||||||
|
|
||||||
class DocumentIndexingStatusApi(DocumentResource):
|
class DocumentIndexingStatusApi(DocumentResource):
|
||||||
document_status_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'indexing_status': fields.String,
|
|
||||||
'processing_started_at': TimestampField,
|
|
||||||
'parsing_completed_at': TimestampField,
|
|
||||||
'cleaning_completed_at': TimestampField,
|
|
||||||
'splitting_completed_at': TimestampField,
|
|
||||||
'completed_at': TimestampField,
|
|
||||||
'paused_at': TimestampField,
|
|
||||||
'error': fields.String,
|
|
||||||
'stopped_at': TimestampField,
|
|
||||||
'completed_segments': fields.Integer,
|
|
||||||
'total_segments': fields.Integer,
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -586,7 +493,7 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||||
document.total_segments = total_segments
|
document.total_segments = total_segments
|
||||||
if document.is_paused:
|
if document.is_paused:
|
||||||
document.indexing_status = 'paused'
|
document.indexing_status = 'paused'
|
||||||
return marshal(document, self.document_status_fields)
|
return marshal(document, document_status_fields)
|
||||||
|
|
||||||
|
|
||||||
class DocumentDetailApi(DocumentResource):
|
class DocumentDetailApi(DocumentResource):
|
||||||
|
|
|
@ -3,7 +3,7 @@ import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, reqparse, fields, marshal
|
from flask_restful import Resource, reqparse, marshal
|
||||||
from werkzeug.exceptions import NotFound, Forbidden
|
from werkzeug.exceptions import NotFound, Forbidden
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
@ -17,6 +17,7 @@ from core.model_providers.model_factory import ModelFactory
|
||||||
from core.login.login import login_required
|
from core.login.login import login_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
|
from fields.segment_fields import segment_fields
|
||||||
from models.dataset import DocumentSegment
|
from models.dataset import DocumentSegment
|
||||||
|
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
|
@ -26,36 +27,6 @@ from tasks.disable_segment_from_index_task import disable_segment_from_index_tas
|
||||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
segment_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'position': fields.Integer,
|
|
||||||
'document_id': fields.String,
|
|
||||||
'content': fields.String,
|
|
||||||
'answer': fields.String,
|
|
||||||
'word_count': fields.Integer,
|
|
||||||
'tokens': fields.Integer,
|
|
||||||
'keywords': fields.List(fields.String),
|
|
||||||
'index_node_id': fields.String,
|
|
||||||
'index_node_hash': fields.String,
|
|
||||||
'hit_count': fields.Integer,
|
|
||||||
'enabled': fields.Boolean,
|
|
||||||
'disabled_at': TimestampField,
|
|
||||||
'disabled_by': fields.String,
|
|
||||||
'status': fields.String,
|
|
||||||
'created_by': fields.String,
|
|
||||||
'created_at': TimestampField,
|
|
||||||
'indexing_at': TimestampField,
|
|
||||||
'completed_at': TimestampField,
|
|
||||||
'error': fields.String,
|
|
||||||
'stopped_at': TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
segment_list_response = {
|
|
||||||
'data': fields.List(fields.Nested(segment_fields)),
|
|
||||||
'has_more': fields.Boolean,
|
|
||||||
'limit': fields.Integer
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetDocumentSegmentListApi(Resource):
|
class DatasetDocumentSegmentListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
|
@ -1,28 +1,19 @@
|
||||||
import datetime
|
|
||||||
import hashlib
|
|
||||||
import tempfile
|
|
||||||
import chardet
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from cachetools import TTLCache
|
from cachetools import TTLCache
|
||||||
from flask import request, current_app
|
from flask import request, current_app
|
||||||
from flask_login import current_user
|
|
||||||
|
import services
|
||||||
from core.login.login import login_required
|
from core.login.login import login_required
|
||||||
from flask_restful import Resource, marshal_with, fields
|
from flask_restful import Resource, marshal_with, fields
|
||||||
from werkzeug.exceptions import NotFound
|
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
|
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
|
||||||
UnsupportedFileTypeError
|
UnsupportedFileTypeError
|
||||||
|
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from core.data_loader.file_extractor import FileExtractor
|
from fields.file_fields import upload_config_fields, file_fields
|
||||||
from extensions.ext_storage import storage
|
|
||||||
from libs.helper import TimestampField
|
from services.file_service import FileService
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.model import UploadFile
|
|
||||||
|
|
||||||
cache = TTLCache(maxsize=None, ttl=30)
|
cache = TTLCache(maxsize=None, ttl=30)
|
||||||
|
|
||||||
|
@ -31,10 +22,6 @@ PREVIEW_WORDS_LIMIT = 3000
|
||||||
|
|
||||||
|
|
||||||
class FileApi(Resource):
|
class FileApi(Resource):
|
||||||
upload_config_fields = {
|
|
||||||
'file_size_limit': fields.Integer,
|
|
||||||
'batch_count_limit': fields.Integer
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@ -48,16 +35,6 @@ class FileApi(Resource):
|
||||||
'batch_count_limit': batch_count_limit
|
'batch_count_limit': batch_count_limit
|
||||||
}, 200
|
}, 200
|
||||||
|
|
||||||
file_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'size': fields.Integer,
|
|
||||||
'extension': fields.String,
|
|
||||||
'mime_type': fields.String,
|
|
||||||
'created_by': fields.String,
|
|
||||||
'created_at': TimestampField,
|
|
||||||
}
|
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@ -73,45 +50,13 @@ class FileApi(Resource):
|
||||||
|
|
||||||
if len(request.files) > 1:
|
if len(request.files) > 1:
|
||||||
raise TooManyFilesError()
|
raise TooManyFilesError()
|
||||||
|
try:
|
||||||
file_content = file.read()
|
upload_file = FileService.upload_file(file)
|
||||||
file_size = len(file_content)
|
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||||
|
raise FileTooLargeError(file_too_large_error.description)
|
||||||
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
if file_size > file_size_limit:
|
|
||||||
message = "({file_size} > {file_size_limit})"
|
|
||||||
raise FileTooLargeError(message)
|
|
||||||
|
|
||||||
extension = file.filename.split('.')[-1]
|
|
||||||
if extension.lower() not in ALLOWED_EXTENSIONS:
|
|
||||||
raise UnsupportedFileTypeError()
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
# user uuid as file name
|
|
||||||
file_uuid = str(uuid.uuid4())
|
|
||||||
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
|
|
||||||
|
|
||||||
# save file to storage
|
|
||||||
storage.save(file_key, file_content)
|
|
||||||
|
|
||||||
# save file to db
|
|
||||||
config = current_app.config
|
|
||||||
upload_file = UploadFile(
|
|
||||||
tenant_id=current_user.current_tenant_id,
|
|
||||||
storage_type=config['STORAGE_TYPE'],
|
|
||||||
key=file_key,
|
|
||||||
name=file.filename,
|
|
||||||
size=file_size,
|
|
||||||
extension=extension,
|
|
||||||
mime_type=file.mimetype,
|
|
||||||
created_by=current_user.id,
|
|
||||||
created_at=datetime.datetime.utcnow(),
|
|
||||||
used=False,
|
|
||||||
hash=hashlib.sha3_256(file_content).hexdigest()
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.add(upload_file)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
return upload_file, 201
|
return upload_file, 201
|
||||||
|
|
||||||
|
|
||||||
|
@ -121,26 +66,7 @@ class FilePreviewApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, file_id):
|
def get(self, file_id):
|
||||||
file_id = str(file_id)
|
file_id = str(file_id)
|
||||||
|
text = FileService.get_file_preview(file_id)
|
||||||
key = file_id + request.path
|
|
||||||
cached_response = cache.get(key)
|
|
||||||
if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
|
|
||||||
return cached_response['response']
|
|
||||||
|
|
||||||
upload_file = db.session.query(UploadFile) \
|
|
||||||
.filter(UploadFile.id == file_id) \
|
|
||||||
.first()
|
|
||||||
|
|
||||||
if not upload_file:
|
|
||||||
raise NotFound("File not found")
|
|
||||||
|
|
||||||
# extract text from file
|
|
||||||
extension = upload_file.extension
|
|
||||||
if extension.lower() not in ALLOWED_EXTENSIONS:
|
|
||||||
raise UnsupportedFileTypeError()
|
|
||||||
|
|
||||||
text = FileExtractor.load(upload_file, return_text=True)
|
|
||||||
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
|
|
||||||
return {'content': text}
|
return {'content': text}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import logging
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from core.login.login import login_required
|
from core.login.login import login_required
|
||||||
from flask_restful import Resource, reqparse, marshal, fields
|
from flask_restful import Resource, reqparse, marshal
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
|
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
@ -14,48 +14,10 @@ from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||||
LLMBadRequestError
|
LLMBadRequestError
|
||||||
from libs.helper import TimestampField
|
from fields.hit_testing_fields import hit_testing_record_fields
|
||||||
from services.dataset_service import DatasetService
|
from services.dataset_service import DatasetService
|
||||||
from services.hit_testing_service import HitTestingService
|
from services.hit_testing_service import HitTestingService
|
||||||
|
|
||||||
document_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'data_source_type': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'doc_type': fields.String,
|
|
||||||
}
|
|
||||||
|
|
||||||
segment_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'position': fields.Integer,
|
|
||||||
'document_id': fields.String,
|
|
||||||
'content': fields.String,
|
|
||||||
'answer': fields.String,
|
|
||||||
'word_count': fields.Integer,
|
|
||||||
'tokens': fields.Integer,
|
|
||||||
'keywords': fields.List(fields.String),
|
|
||||||
'index_node_id': fields.String,
|
|
||||||
'index_node_hash': fields.String,
|
|
||||||
'hit_count': fields.Integer,
|
|
||||||
'enabled': fields.Boolean,
|
|
||||||
'disabled_at': TimestampField,
|
|
||||||
'disabled_by': fields.String,
|
|
||||||
'status': fields.String,
|
|
||||||
'created_by': fields.String,
|
|
||||||
'created_at': TimestampField,
|
|
||||||
'indexing_at': TimestampField,
|
|
||||||
'completed_at': TimestampField,
|
|
||||||
'error': fields.String,
|
|
||||||
'stopped_at': TimestampField,
|
|
||||||
'document': fields.Nested(document_fields),
|
|
||||||
}
|
|
||||||
|
|
||||||
hit_testing_record_fields = {
|
|
||||||
'segment': fields.Nested(segment_fields),
|
|
||||||
'score': fields.Float,
|
|
||||||
'tsne_position': fields.Raw
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class HitTestingApi(Resource):
|
class HitTestingApi(Resource):
|
||||||
|
|
||||||
|
|
|
@ -7,26 +7,12 @@ from werkzeug.exceptions import NotFound
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.explore.error import NotChatAppError
|
from controllers.console.explore.error import NotChatAppError
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
|
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
||||||
from services.web_conversation_service import WebConversationService
|
from services.web_conversation_service import WebConversationService
|
||||||
|
|
||||||
conversation_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'inputs': fields.Raw,
|
|
||||||
'status': fields.String,
|
|
||||||
'introduction': fields.String,
|
|
||||||
'created_at': TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
conversation_infinite_scroll_pagination_fields = {
|
|
||||||
'limit': fields.Integer,
|
|
||||||
'has_more': fields.Boolean,
|
|
||||||
'data': fields.List(fields.Nested(conversation_fields))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationListApi(InstalledAppResource):
|
class ConversationListApi(InstalledAppResource):
|
||||||
|
|
||||||
|
@ -76,7 +62,7 @@ class ConversationApi(InstalledAppResource):
|
||||||
|
|
||||||
class ConversationRenameApi(InstalledAppResource):
|
class ConversationRenameApi(InstalledAppResource):
|
||||||
|
|
||||||
@marshal_with(conversation_fields)
|
@marshal_with(simple_conversation_fields)
|
||||||
def post(self, installed_app, c_id):
|
def post(self, installed_app, c_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if app_model.mode != 'chat':
|
if app_model.mode != 'chat':
|
||||||
|
|
|
@ -11,32 +11,11 @@ from controllers.console import api
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from fields.installed_app_fields import installed_app_list_fields
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from models.model import App, InstalledApp, RecommendedApp
|
from models.model import App, InstalledApp, RecommendedApp
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
|
|
||||||
app_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'mode': fields.String,
|
|
||||||
'icon': fields.String,
|
|
||||||
'icon_background': fields.String
|
|
||||||
}
|
|
||||||
|
|
||||||
installed_app_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'app': fields.Nested(app_fields),
|
|
||||||
'app_owner_tenant_id': fields.String,
|
|
||||||
'is_pinned': fields.Boolean,
|
|
||||||
'last_used_at': TimestampField,
|
|
||||||
'editable': fields.Boolean,
|
|
||||||
'uninstallable': fields.Boolean,
|
|
||||||
}
|
|
||||||
|
|
||||||
installed_app_list_fields = {
|
|
||||||
'installed_apps': fields.List(fields.Nested(installed_app_fields))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class InstalledAppsListApi(Resource):
|
class InstalledAppsListApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
|
|
|
@ -17,6 +17,7 @@ from controllers.console.explore.error import NotCompletionAppError, AppSuggeste
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||||
|
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||||
from libs.helper import uuid_value, TimestampField
|
from libs.helper import uuid_value, TimestampField
|
||||||
from services.completion_service import CompletionService
|
from services.completion_service import CompletionService
|
||||||
from services.errors.app import MoreLikeThisDisabledError
|
from services.errors.app import MoreLikeThisDisabledError
|
||||||
|
@ -26,45 +27,6 @@ from services.message_service import MessageService
|
||||||
|
|
||||||
|
|
||||||
class MessageListApi(InstalledAppResource):
|
class MessageListApi(InstalledAppResource):
|
||||||
feedback_fields = {
|
|
||||||
'rating': fields.String
|
|
||||||
}
|
|
||||||
|
|
||||||
retriever_resource_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'message_id': fields.String,
|
|
||||||
'position': fields.Integer,
|
|
||||||
'dataset_id': fields.String,
|
|
||||||
'dataset_name': fields.String,
|
|
||||||
'document_id': fields.String,
|
|
||||||
'document_name': fields.String,
|
|
||||||
'data_source_type': fields.String,
|
|
||||||
'segment_id': fields.String,
|
|
||||||
'score': fields.Float,
|
|
||||||
'hit_count': fields.Integer,
|
|
||||||
'word_count': fields.Integer,
|
|
||||||
'segment_position': fields.Integer,
|
|
||||||
'index_node_hash': fields.String,
|
|
||||||
'content': fields.String,
|
|
||||||
'created_at': TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
message_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'conversation_id': fields.String,
|
|
||||||
'inputs': fields.Raw,
|
|
||||||
'query': fields.String,
|
|
||||||
'answer': fields.String,
|
|
||||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
|
||||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
|
||||||
'created_at': TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
message_infinite_scroll_pagination_fields = {
|
|
||||||
'limit': fields.Integer,
|
|
||||||
'has_more': fields.Boolean,
|
|
||||||
'data': fields.List(fields.Nested(message_fields))
|
|
||||||
}
|
|
||||||
|
|
||||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||||
def get(self, installed_app):
|
def get(self, installed_app):
|
||||||
|
|
|
@ -6,31 +6,17 @@ from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||||
|
from fields.conversation_fields import conversation_with_model_config_infinite_scroll_pagination_fields, \
|
||||||
|
conversation_with_model_config_fields
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
||||||
from services.web_conversation_service import WebConversationService
|
from services.web_conversation_service import WebConversationService
|
||||||
|
|
||||||
conversation_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'inputs': fields.Raw,
|
|
||||||
'status': fields.String,
|
|
||||||
'introduction': fields.String,
|
|
||||||
'created_at': TimestampField,
|
|
||||||
'model_config': fields.Raw,
|
|
||||||
}
|
|
||||||
|
|
||||||
conversation_infinite_scroll_pagination_fields = {
|
|
||||||
'limit': fields.Integer,
|
|
||||||
'has_more': fields.Boolean,
|
|
||||||
'data': fields.List(fields.Nested(conversation_fields))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class UniversalChatConversationListApi(UniversalChatResource):
|
class UniversalChatConversationListApi(UniversalChatResource):
|
||||||
|
|
||||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
@marshal_with(conversation_with_model_config_infinite_scroll_pagination_fields)
|
||||||
def get(self, universal_app):
|
def get(self, universal_app):
|
||||||
app_model = universal_app
|
app_model = universal_app
|
||||||
|
|
||||||
|
@ -73,7 +59,7 @@ class UniversalChatConversationApi(UniversalChatResource):
|
||||||
|
|
||||||
class UniversalChatConversationRenameApi(UniversalChatResource):
|
class UniversalChatConversationRenameApi(UniversalChatResource):
|
||||||
|
|
||||||
@marshal_with(conversation_fields)
|
@marshal_with(conversation_with_model_config_fields)
|
||||||
def post(self, universal_app, c_id):
|
def post(self, universal_app, c_id):
|
||||||
app_model = universal_app
|
app_model = universal_app
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
|
|
@ -9,4 +9,4 @@ api = ExternalApi(bp)
|
||||||
|
|
||||||
from .app import completion, app, conversation, message, audio
|
from .app import completion, app, conversation, message, audio
|
||||||
|
|
||||||
from .dataset import document
|
from .dataset import document, segment, dataset
|
||||||
|
|
|
@ -8,25 +8,11 @@ from controllers.service_api import api
|
||||||
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
||||||
from controllers.service_api.app.error import NotChatAppError
|
from controllers.service_api.app.error import NotChatAppError
|
||||||
from controllers.service_api.wraps import AppApiResource
|
from controllers.service_api.wraps import AppApiResource
|
||||||
|
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
import services
|
import services
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
|
|
||||||
conversation_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'inputs': fields.Raw,
|
|
||||||
'status': fields.String,
|
|
||||||
'introduction': fields.String,
|
|
||||||
'created_at': TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
conversation_infinite_scroll_pagination_fields = {
|
|
||||||
'limit': fields.Integer,
|
|
||||||
'has_more': fields.Boolean,
|
|
||||||
'data': fields.List(fields.Nested(conversation_fields))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationApi(AppApiResource):
|
class ConversationApi(AppApiResource):
|
||||||
|
|
||||||
|
@ -50,7 +36,7 @@ class ConversationApi(AppApiResource):
|
||||||
raise NotFound("Last Conversation Not Exists.")
|
raise NotFound("Last Conversation Not Exists.")
|
||||||
|
|
||||||
class ConversationDetailApi(AppApiResource):
|
class ConversationDetailApi(AppApiResource):
|
||||||
@marshal_with(conversation_fields)
|
@marshal_with(simple_conversation_fields)
|
||||||
def delete(self, app_model, end_user, c_id):
|
def delete(self, app_model, end_user, c_id):
|
||||||
if app_model.mode != 'chat':
|
if app_model.mode != 'chat':
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
@ -70,7 +56,7 @@ class ConversationDetailApi(AppApiResource):
|
||||||
|
|
||||||
class ConversationRenameApi(AppApiResource):
|
class ConversationRenameApi(AppApiResource):
|
||||||
|
|
||||||
@marshal_with(conversation_fields)
|
@marshal_with(simple_conversation_fields)
|
||||||
def post(self, app_model, end_user, c_id):
|
def post(self, app_model, end_user, c_id):
|
||||||
if app_model.mode != 'chat':
|
if app_model.mode != 'chat':
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
84
api/controllers/service_api/dataset/dataset.py
Normal file
84
api/controllers/service_api/dataset/dataset.py
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
from flask import request
|
||||||
|
from flask_restful import reqparse, marshal
|
||||||
|
import services.dataset_service
|
||||||
|
from controllers.service_api import api
|
||||||
|
from controllers.service_api.dataset.error import DatasetNameDuplicateError
|
||||||
|
from controllers.service_api.wraps import DatasetApiResource
|
||||||
|
from core.login.login import current_user
|
||||||
|
from core.model_providers.models.entity.model_params import ModelType
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from fields.dataset_fields import dataset_detail_fields
|
||||||
|
from models.account import Account, TenantAccountJoin
|
||||||
|
from models.dataset import Dataset
|
||||||
|
from services.dataset_service import DatasetService
|
||||||
|
from services.provider_service import ProviderService
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_name(name):
|
||||||
|
if not name or len(name) < 1 or len(name) > 40:
|
||||||
|
raise ValueError('Name must be between 1 to 40 characters.')
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetApi(DatasetApiResource):
|
||||||
|
"""Resource for get datasets."""
|
||||||
|
|
||||||
|
def get(self, tenant_id):
|
||||||
|
page = request.args.get('page', default=1, type=int)
|
||||||
|
limit = request.args.get('limit', default=20, type=int)
|
||||||
|
provider = request.args.get('provider', default="vendor")
|
||||||
|
datasets, total = DatasetService.get_datasets(page, limit, provider,
|
||||||
|
tenant_id, current_user)
|
||||||
|
# check embedding setting
|
||||||
|
provider_service = ProviderService()
|
||||||
|
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
|
||||||
|
ModelType.EMBEDDINGS.value)
|
||||||
|
model_names = []
|
||||||
|
for valid_model in valid_model_list:
|
||||||
|
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
|
||||||
|
data = marshal(datasets, dataset_detail_fields)
|
||||||
|
for item in data:
|
||||||
|
if item['indexing_technique'] == 'high_quality':
|
||||||
|
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||||
|
if item_model in model_names:
|
||||||
|
item['embedding_available'] = True
|
||||||
|
else:
|
||||||
|
item['embedding_available'] = False
|
||||||
|
else:
|
||||||
|
item['embedding_available'] = True
|
||||||
|
response = {
|
||||||
|
'data': data,
|
||||||
|
'has_more': len(datasets) == limit,
|
||||||
|
'limit': limit,
|
||||||
|
'total': total,
|
||||||
|
'page': page
|
||||||
|
}
|
||||||
|
return response, 200
|
||||||
|
|
||||||
|
"""Resource for datasets."""
|
||||||
|
|
||||||
|
def post(self, tenant_id):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('name', nullable=False, required=True,
|
||||||
|
help='type is required. Name must be between 1 to 40 characters.',
|
||||||
|
type=_validate_name)
|
||||||
|
parser.add_argument('indexing_technique', type=str, location='json',
|
||||||
|
choices=('high_quality', 'economy'),
|
||||||
|
help='Invalid indexing technique.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
dataset = DatasetService.create_empty_dataset(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
name=args['name'],
|
||||||
|
indexing_technique=args['indexing_technique'],
|
||||||
|
account=current_user
|
||||||
|
)
|
||||||
|
except services.errors.dataset.DatasetNameDuplicateError:
|
||||||
|
raise DatasetNameDuplicateError()
|
||||||
|
|
||||||
|
return marshal(dataset, dataset_detail_fields), 200
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(DatasetApi, '/datasets')
|
||||||
|
|
|
@ -1,114 +1,291 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from flask import current_app
|
from flask import current_app, request
|
||||||
from flask_restful import reqparse
|
from flask_restful import reqparse, marshal
|
||||||
|
from sqlalchemy import desc
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
import services.dataset_service
|
import services.dataset_service
|
||||||
from controllers.service_api import api
|
from controllers.service_api import api
|
||||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||||
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
|
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
|
||||||
DatasetNotInitedError
|
NoFileUploadedError, TooManyFilesError
|
||||||
from controllers.service_api.wraps import DatasetApiResource
|
from controllers.service_api.wraps import DatasetApiResource
|
||||||
|
from core.login.login import current_user
|
||||||
from core.model_providers.error import ProviderTokenNotInitError
|
from core.model_providers.error import ProviderTokenNotInitError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
|
from fields.document_fields import document_fields, document_status_fields
|
||||||
|
from models.dataset import Dataset, Document, DocumentSegment
|
||||||
from models.model import UploadFile
|
from models.model import UploadFile
|
||||||
from services.dataset_service import DocumentService
|
from services.dataset_service import DocumentService
|
||||||
|
from services.file_service import FileService
|
||||||
|
|
||||||
|
|
||||||
class DocumentListApi(DatasetApiResource):
|
class DocumentAddByTextApi(DatasetApiResource):
|
||||||
"""Resource for documents."""
|
"""Resource for documents."""
|
||||||
|
|
||||||
def post(self, dataset):
|
def post(self, tenant_id, dataset_id):
|
||||||
"""Create document."""
|
"""Create document by text."""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
|
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
|
||||||
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
|
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
|
||||||
parser.add_argument('doc_type', type=str, location='json')
|
parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
|
||||||
parser.add_argument('doc_metadata', type=dict, location='json')
|
parser.add_argument('original_document_id', type=str, required=False, location='json')
|
||||||
|
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||||
|
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||||
|
location='json')
|
||||||
|
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
|
||||||
|
location='json')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
dataset_id = str(dataset_id)
|
||||||
|
tenant_id = str(tenant_id)
|
||||||
|
dataset = db.session.query(Dataset).filter(
|
||||||
|
Dataset.tenant_id == tenant_id,
|
||||||
|
Dataset.id == dataset_id
|
||||||
|
).first()
|
||||||
|
|
||||||
if not dataset.indexing_technique:
|
if not dataset:
|
||||||
raise DatasetNotInitedError("Dataset indexing technique must be set.")
|
raise ValueError('Dataset is not exist.')
|
||||||
|
|
||||||
doc_type = args.get('doc_type')
|
if not dataset.indexing_technique and not args['indexing_technique']:
|
||||||
doc_metadata = args.get('doc_metadata')
|
raise ValueError('indexing_technique is required.')
|
||||||
|
|
||||||
if doc_type and doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA:
|
upload_file = FileService.upload_text(args.get('text'), args.get('name'))
|
||||||
raise ValueError('Invalid doc_type.')
|
data_source = {
|
||||||
|
'type': 'upload_file',
|
||||||
# user uuid as file name
|
'info_list': {
|
||||||
file_uuid = str(uuid.uuid4())
|
'data_source_type': 'upload_file',
|
||||||
file_key = 'upload_files/' + dataset.tenant_id + '/' + file_uuid + '.txt'
|
'file_info_list': {
|
||||||
|
'file_ids': [upload_file.id]
|
||||||
# save file to storage
|
}
|
||||||
storage.save(file_key, args.get('text'))
|
|
||||||
|
|
||||||
# save file to db
|
|
||||||
config = current_app.config
|
|
||||||
upload_file = UploadFile(
|
|
||||||
tenant_id=dataset.tenant_id,
|
|
||||||
storage_type=config['STORAGE_TYPE'],
|
|
||||||
key=file_key,
|
|
||||||
name=args.get('name') + '.txt',
|
|
||||||
size=len(args.get('text')),
|
|
||||||
extension='txt',
|
|
||||||
mime_type='text/plain',
|
|
||||||
created_by=dataset.created_by,
|
|
||||||
created_at=datetime.datetime.utcnow(),
|
|
||||||
used=True,
|
|
||||||
used_by=dataset.created_by,
|
|
||||||
used_at=datetime.datetime.utcnow()
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.add(upload_file)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
document_data = {
|
|
||||||
'data_source': {
|
|
||||||
'type': 'upload_file',
|
|
||||||
'info': [
|
|
||||||
{
|
|
||||||
'upload_file_id': upload_file.id
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
args['data_source'] = data_source
|
||||||
|
# validate args
|
||||||
|
DocumentService.document_create_args_validate(args)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
document_data=document_data,
|
document_data=args,
|
||||||
account=dataset.created_by_account,
|
account=current_user,
|
||||||
dataset_process_rule=dataset.latest_process_rule,
|
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
||||||
created_from='api'
|
created_from='api'
|
||||||
)
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
document = documents[0]
|
document = documents[0]
|
||||||
if doc_type and doc_metadata:
|
|
||||||
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
|
|
||||||
|
|
||||||
document.doc_metadata = {}
|
documents_and_batch_fields = {
|
||||||
|
'document': marshal(document, document_fields),
|
||||||
for key, value_type in metadata_schema.items():
|
'batch': batch
|
||||||
value = doc_metadata.get(key)
|
}
|
||||||
if value is not None and isinstance(value, value_type):
|
return documents_and_batch_fields, 200
|
||||||
document.doc_metadata[key] = value
|
|
||||||
|
|
||||||
document.doc_type = doc_type
|
|
||||||
document.updated_at = datetime.datetime.utcnow()
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
return {'id': document.id}
|
|
||||||
|
|
||||||
|
|
||||||
class DocumentApi(DatasetApiResource):
|
class DocumentUpdateByTextApi(DatasetApiResource):
|
||||||
def delete(self, dataset, document_id):
|
"""Resource for update documents."""
|
||||||
|
|
||||||
|
def post(self, tenant_id, dataset_id, document_id):
|
||||||
|
"""Update document by text."""
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
|
||||||
|
parser.add_argument('text', type=str, required=False, nullable=True, location='json')
|
||||||
|
parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
|
||||||
|
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||||
|
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||||
|
location='json')
|
||||||
|
args = parser.parse_args()
|
||||||
|
dataset_id = str(dataset_id)
|
||||||
|
tenant_id = str(tenant_id)
|
||||||
|
dataset = db.session.query(Dataset).filter(
|
||||||
|
Dataset.tenant_id == tenant_id,
|
||||||
|
Dataset.id == dataset_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not dataset:
|
||||||
|
raise ValueError('Dataset is not exist.')
|
||||||
|
|
||||||
|
if args['text']:
|
||||||
|
upload_file = FileService.upload_text(args.get('text'), args.get('name'))
|
||||||
|
data_source = {
|
||||||
|
'type': 'upload_file',
|
||||||
|
'info_list': {
|
||||||
|
'data_source_type': 'upload_file',
|
||||||
|
'file_info_list': {
|
||||||
|
'file_ids': [upload_file.id]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
args['data_source'] = data_source
|
||||||
|
# validate args
|
||||||
|
args['original_document_id'] = str(document_id)
|
||||||
|
DocumentService.document_create_args_validate(args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||||
|
dataset=dataset,
|
||||||
|
document_data=args,
|
||||||
|
account=current_user,
|
||||||
|
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
||||||
|
created_from='api'
|
||||||
|
)
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
document = documents[0]
|
||||||
|
|
||||||
|
documents_and_batch_fields = {
|
||||||
|
'document': marshal(document, document_fields),
|
||||||
|
'batch': batch
|
||||||
|
}
|
||||||
|
return documents_and_batch_fields, 200
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentAddByFileApi(DatasetApiResource):
|
||||||
|
"""Resource for documents."""
|
||||||
|
def post(self, tenant_id, dataset_id):
|
||||||
|
"""Create document by upload file."""
|
||||||
|
args = {}
|
||||||
|
if 'data' in request.form:
|
||||||
|
args = json.loads(request.form['data'])
|
||||||
|
if 'doc_form' not in args:
|
||||||
|
args['doc_form'] = 'text_model'
|
||||||
|
if 'doc_language' not in args:
|
||||||
|
args['doc_language'] = 'English'
|
||||||
|
# get dataset info
|
||||||
|
dataset_id = str(dataset_id)
|
||||||
|
tenant_id = str(tenant_id)
|
||||||
|
dataset = db.session.query(Dataset).filter(
|
||||||
|
Dataset.tenant_id == tenant_id,
|
||||||
|
Dataset.id == dataset_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not dataset:
|
||||||
|
raise ValueError('Dataset is not exist.')
|
||||||
|
if not dataset.indexing_technique and not args['indexing_technique']:
|
||||||
|
raise ValueError('indexing_technique is required.')
|
||||||
|
|
||||||
|
# save file info
|
||||||
|
file = request.files['file']
|
||||||
|
# check file
|
||||||
|
if 'file' not in request.files:
|
||||||
|
raise NoFileUploadedError()
|
||||||
|
|
||||||
|
if len(request.files) > 1:
|
||||||
|
raise TooManyFilesError()
|
||||||
|
|
||||||
|
upload_file = FileService.upload_file(file)
|
||||||
|
data_source = {
|
||||||
|
'type': 'upload_file',
|
||||||
|
'info_list': {
|
||||||
|
'file_info_list': {
|
||||||
|
'file_ids': [upload_file.id]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
args['data_source'] = data_source
|
||||||
|
# validate args
|
||||||
|
DocumentService.document_create_args_validate(args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||||
|
dataset=dataset,
|
||||||
|
document_data=args,
|
||||||
|
account=dataset.created_by_account,
|
||||||
|
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
||||||
|
created_from='api'
|
||||||
|
)
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
document = documents[0]
|
||||||
|
documents_and_batch_fields = {
|
||||||
|
'document': marshal(document, document_fields),
|
||||||
|
'batch': batch
|
||||||
|
}
|
||||||
|
return documents_and_batch_fields, 200
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentUpdateByFileApi(DatasetApiResource):
|
||||||
|
"""Resource for update documents."""
|
||||||
|
|
||||||
|
def post(self, tenant_id, dataset_id, document_id):
|
||||||
|
"""Update document by upload file."""
|
||||||
|
args = {}
|
||||||
|
if 'data' in request.form:
|
||||||
|
args = json.loads(request.form['data'])
|
||||||
|
if 'doc_form' not in args:
|
||||||
|
args['doc_form'] = 'text_model'
|
||||||
|
if 'doc_language' not in args:
|
||||||
|
args['doc_language'] = 'English'
|
||||||
|
|
||||||
|
# get dataset info
|
||||||
|
dataset_id = str(dataset_id)
|
||||||
|
tenant_id = str(tenant_id)
|
||||||
|
dataset = db.session.query(Dataset).filter(
|
||||||
|
Dataset.tenant_id == tenant_id,
|
||||||
|
Dataset.id == dataset_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not dataset:
|
||||||
|
raise ValueError('Dataset is not exist.')
|
||||||
|
if 'file' in request.files:
|
||||||
|
# save file info
|
||||||
|
file = request.files['file']
|
||||||
|
|
||||||
|
|
||||||
|
if len(request.files) > 1:
|
||||||
|
raise TooManyFilesError()
|
||||||
|
|
||||||
|
upload_file = FileService.upload_file(file)
|
||||||
|
data_source = {
|
||||||
|
'type': 'upload_file',
|
||||||
|
'info_list': {
|
||||||
|
'file_info_list': {
|
||||||
|
'file_ids': [upload_file.id]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
args['data_source'] = data_source
|
||||||
|
# validate args
|
||||||
|
args['original_document_id'] = str(document_id)
|
||||||
|
DocumentService.document_create_args_validate(args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||||
|
dataset=dataset,
|
||||||
|
document_data=args,
|
||||||
|
account=dataset.created_by_account,
|
||||||
|
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
||||||
|
created_from='api'
|
||||||
|
)
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
document = documents[0]
|
||||||
|
documents_and_batch_fields = {
|
||||||
|
'document': marshal(document, document_fields),
|
||||||
|
'batch': batch
|
||||||
|
}
|
||||||
|
return documents_and_batch_fields, 200
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentDeleteApi(DatasetApiResource):
|
||||||
|
def delete(self, tenant_id, dataset_id, document_id):
|
||||||
"""Delete document."""
|
"""Delete document."""
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
|
dataset_id = str(dataset_id)
|
||||||
|
tenant_id = str(tenant_id)
|
||||||
|
|
||||||
|
# get dataset info
|
||||||
|
dataset = db.session.query(Dataset).filter(
|
||||||
|
Dataset.tenant_id == tenant_id,
|
||||||
|
Dataset.id == dataset_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not dataset:
|
||||||
|
raise ValueError('Dataset is not exist.')
|
||||||
|
|
||||||
document = DocumentService.get_document(dataset.id, document_id)
|
document = DocumentService.get_document(dataset.id, document_id)
|
||||||
|
|
||||||
|
@ -126,8 +303,85 @@ class DocumentApi(DatasetApiResource):
|
||||||
except services.errors.document.DocumentIndexingError:
|
except services.errors.document.DocumentIndexingError:
|
||||||
raise DocumentIndexingError('Cannot delete document during indexing.')
|
raise DocumentIndexingError('Cannot delete document during indexing.')
|
||||||
|
|
||||||
return {'result': 'success'}, 204
|
return {'result': 'success'}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(DocumentListApi, '/documents')
|
class DocumentListApi(DatasetApiResource):
|
||||||
api.add_resource(DocumentApi, '/documents/<uuid:document_id>')
|
def get(self, tenant_id, dataset_id):
|
||||||
|
dataset_id = str(dataset_id)
|
||||||
|
tenant_id = str(tenant_id)
|
||||||
|
page = request.args.get('page', default=1, type=int)
|
||||||
|
limit = request.args.get('limit', default=20, type=int)
|
||||||
|
search = request.args.get('keyword', default=None, type=str)
|
||||||
|
dataset = db.session.query(Dataset).filter(
|
||||||
|
Dataset.tenant_id == tenant_id,
|
||||||
|
Dataset.id == dataset_id
|
||||||
|
).first()
|
||||||
|
if not dataset:
|
||||||
|
raise NotFound('Dataset not found.')
|
||||||
|
|
||||||
|
query = Document.query.filter_by(
|
||||||
|
dataset_id=str(dataset_id), tenant_id=tenant_id)
|
||||||
|
|
||||||
|
if search:
|
||||||
|
search = f'%{search}%'
|
||||||
|
query = query.filter(Document.name.like(search))
|
||||||
|
|
||||||
|
query = query.order_by(desc(Document.created_at))
|
||||||
|
|
||||||
|
paginated_documents = query.paginate(
|
||||||
|
page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||||
|
documents = paginated_documents.items
|
||||||
|
|
||||||
|
response = {
|
||||||
|
'data': marshal(documents, document_fields),
|
||||||
|
'has_more': len(documents) == limit,
|
||||||
|
'limit': limit,
|
||||||
|
'total': paginated_documents.total,
|
||||||
|
'page': page
|
||||||
|
}
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentIndexingStatusApi(DatasetApiResource):
|
||||||
|
def get(self, tenant_id, dataset_id, batch):
|
||||||
|
dataset_id = str(dataset_id)
|
||||||
|
batch = str(batch)
|
||||||
|
tenant_id = str(tenant_id)
|
||||||
|
# get dataset
|
||||||
|
dataset = db.session.query(Dataset).filter(
|
||||||
|
Dataset.tenant_id == tenant_id,
|
||||||
|
Dataset.id == dataset_id
|
||||||
|
).first()
|
||||||
|
if not dataset:
|
||||||
|
raise NotFound('Dataset not found.')
|
||||||
|
# get documents
|
||||||
|
documents = DocumentService.get_batch_documents(dataset_id, batch)
|
||||||
|
if not documents:
|
||||||
|
raise NotFound('Documents not found.')
|
||||||
|
documents_status = []
|
||||||
|
for document in documents:
|
||||||
|
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
|
||||||
|
DocumentSegment.document_id == str(document.id),
|
||||||
|
DocumentSegment.status != 're_segment').count()
|
||||||
|
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
|
||||||
|
DocumentSegment.status != 're_segment').count()
|
||||||
|
document.completed_segments = completed_segments
|
||||||
|
document.total_segments = total_segments
|
||||||
|
if document.is_paused:
|
||||||
|
document.indexing_status = 'paused'
|
||||||
|
documents_status.append(marshal(document, document_status_fields))
|
||||||
|
data = {
|
||||||
|
'data': documents_status
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(DocumentAddByTextApi, '/datasets/<uuid:dataset_id>/document/create_by_text')
|
||||||
|
api.add_resource(DocumentAddByFileApi, '/datasets/<uuid:dataset_id>/document/create_by_file')
|
||||||
|
api.add_resource(DocumentUpdateByTextApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text')
|
||||||
|
api.add_resource(DocumentUpdateByFileApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file')
|
||||||
|
api.add_resource(DocumentDeleteApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
|
||||||
|
api.add_resource(DocumentListApi, '/datasets/<uuid:dataset_id>/documents')
|
||||||
|
api.add_resource(DocumentIndexingStatusApi, '/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status')
|
||||||
|
|
|
@ -1,20 +1,73 @@
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
from libs.exception import BaseHTTPException
|
from libs.exception import BaseHTTPException
|
||||||
|
|
||||||
|
|
||||||
|
class NoFileUploadedError(BaseHTTPException):
|
||||||
|
error_code = 'no_file_uploaded'
|
||||||
|
description = "Please upload your file."
|
||||||
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
|
class TooManyFilesError(BaseHTTPException):
|
||||||
|
error_code = 'too_many_files'
|
||||||
|
description = "Only one file is allowed."
|
||||||
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
|
class FileTooLargeError(BaseHTTPException):
|
||||||
|
error_code = 'file_too_large'
|
||||||
|
description = "File size exceeded. {message}"
|
||||||
|
code = 413
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedFileTypeError(BaseHTTPException):
|
||||||
|
error_code = 'unsupported_file_type'
|
||||||
|
description = "File type not allowed."
|
||||||
|
code = 415
|
||||||
|
|
||||||
|
|
||||||
|
class HighQualityDatasetOnlyError(BaseHTTPException):
|
||||||
|
error_code = 'high_quality_dataset_only'
|
||||||
|
description = "Current operation only supports 'high-quality' datasets."
|
||||||
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetNotInitializedError(BaseHTTPException):
|
||||||
|
error_code = 'dataset_not_initialized'
|
||||||
|
description = "The dataset is still being initialized or indexing. Please wait a moment."
|
||||||
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class ArchivedDocumentImmutableError(BaseHTTPException):
|
class ArchivedDocumentImmutableError(BaseHTTPException):
|
||||||
error_code = 'archived_document_immutable'
|
error_code = 'archived_document_immutable'
|
||||||
description = "Cannot operate when document was archived."
|
description = "The archived document is not editable."
|
||||||
code = 403
|
code = 403
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetNameDuplicateError(BaseHTTPException):
|
||||||
|
error_code = 'dataset_name_duplicate'
|
||||||
|
description = "The dataset name already exists. Please modify your dataset name."
|
||||||
|
code = 409
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidActionError(BaseHTTPException):
|
||||||
|
error_code = 'invalid_action'
|
||||||
|
description = "Invalid action."
|
||||||
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentAlreadyFinishedError(BaseHTTPException):
|
||||||
|
error_code = 'document_already_finished'
|
||||||
|
description = "The document has been processed. Please refresh the page or go to the document details."
|
||||||
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class DocumentIndexingError(BaseHTTPException):
|
class DocumentIndexingError(BaseHTTPException):
|
||||||
error_code = 'document_indexing'
|
error_code = 'document_indexing'
|
||||||
description = "Cannot operate document during indexing."
|
description = "The document is being processed and cannot be edited."
|
||||||
code = 403
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class DatasetNotInitedError(BaseHTTPException):
|
class InvalidMetadataError(BaseHTTPException):
|
||||||
error_code = 'dataset_not_inited'
|
error_code = 'invalid_metadata'
|
||||||
description = "The dataset is still being initialized or indexing. Please wait a moment."
|
description = "The metadata content is incorrect. Please check and verify."
|
||||||
code = 403
|
code = 400
|
||||||
|
|
59
api/controllers/service_api/dataset/segment.py
Normal file
59
api/controllers/service_api/dataset/segment.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
from flask_login import current_user
|
||||||
|
from flask_restful import reqparse, marshal
|
||||||
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.service_api import api
|
||||||
|
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||||
|
from controllers.service_api.wraps import DatasetApiResource
|
||||||
|
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||||
|
from core.model_providers.model_factory import ModelFactory
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from fields.segment_fields import segment_fields
|
||||||
|
from models.dataset import Dataset
|
||||||
|
from services.dataset_service import DocumentService, SegmentService
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentApi(DatasetApiResource):
|
||||||
|
"""Resource for segments."""
|
||||||
|
def post(self, tenant_id, dataset_id, document_id):
|
||||||
|
"""Create single segment."""
|
||||||
|
# check dataset
|
||||||
|
dataset_id = str(dataset_id)
|
||||||
|
tenant_id = str(tenant_id)
|
||||||
|
dataset = db.session.query(Dataset).filter(
|
||||||
|
Dataset.tenant_id == tenant_id,
|
||||||
|
Dataset.id == dataset_id
|
||||||
|
).first()
|
||||||
|
# check document
|
||||||
|
document_id = str(document_id)
|
||||||
|
document = DocumentService.get_document(dataset.id, document_id)
|
||||||
|
if not document:
|
||||||
|
raise NotFound('Document not found.')
|
||||||
|
# check embedding model setting
|
||||||
|
if dataset.indexing_technique == 'high_quality':
|
||||||
|
try:
|
||||||
|
ModelFactory.get_embedding_model(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
model_provider_name=dataset.embedding_model_provider,
|
||||||
|
model_name=dataset.embedding_model
|
||||||
|
)
|
||||||
|
except LLMBadRequestError:
|
||||||
|
raise ProviderNotInitializeError(
|
||||||
|
f"No Embedding Model available. Please configure a valid provider "
|
||||||
|
f"in the Settings -> Model Provider.")
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
# validate args
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
|
||||||
|
args = parser.parse_args()
|
||||||
|
for args_item in args['segments']:
|
||||||
|
SegmentService.segment_create_args_validate(args_item, document)
|
||||||
|
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
|
||||||
|
return {
|
||||||
|
'data': marshal(segments, segment_fields),
|
||||||
|
'doc_form': document.doc_form
|
||||||
|
}, 200
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
|
|
@ -2,11 +2,14 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
from flask import request
|
from flask import request, current_app
|
||||||
|
from flask_login import user_logged_in
|
||||||
from flask_restful import Resource
|
from flask_restful import Resource
|
||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
|
from core.login.login import _get_user
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from models.account import Tenant, TenantAccountJoin, Account
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
from models.model import ApiToken, App
|
from models.model import ApiToken, App
|
||||||
|
|
||||||
|
@ -43,12 +46,24 @@ def validate_dataset_token(view=None):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
api_token = validate_and_get_api_token('dataset')
|
api_token = validate_and_get_api_token('dataset')
|
||||||
|
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
|
||||||
dataset = db.session.query(Dataset).filter(Dataset.id == api_token.dataset_id).first()
|
.filter(Tenant.id == api_token.tenant_id) \
|
||||||
if not dataset:
|
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
|
||||||
raise NotFound()
|
.filter(TenantAccountJoin.role == 'owner') \
|
||||||
|
.one_or_none()
|
||||||
return view(dataset, *args, **kwargs)
|
if tenant_account_join:
|
||||||
|
tenant, ta = tenant_account_join
|
||||||
|
account = Account.query.filter_by(id=ta.account_id).first()
|
||||||
|
# Login admin
|
||||||
|
if account:
|
||||||
|
account.current_tenant = tenant
|
||||||
|
current_app.login_manager._update_request_context_with_user(account)
|
||||||
|
user_logged_in.send(current_app._get_current_object(), user=_get_user())
|
||||||
|
else:
|
||||||
|
raise Unauthorized("Tenant owner account is not exist.")
|
||||||
|
else:
|
||||||
|
raise Unauthorized("Tenant is not exist.")
|
||||||
|
return view(api_token.tenant_id, *args, **kwargs)
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
if view:
|
if view:
|
||||||
|
|
|
@ -6,26 +6,12 @@ from werkzeug.exceptions import NotFound
|
||||||
from controllers.web import api
|
from controllers.web import api
|
||||||
from controllers.web.error import NotChatAppError
|
from controllers.web.error import NotChatAppError
|
||||||
from controllers.web.wraps import WebApiResource
|
from controllers.web.wraps import WebApiResource
|
||||||
|
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
||||||
from services.web_conversation_service import WebConversationService
|
from services.web_conversation_service import WebConversationService
|
||||||
|
|
||||||
conversation_fields = {
|
|
||||||
'id': fields.String,
|
|
||||||
'name': fields.String,
|
|
||||||
'inputs': fields.Raw,
|
|
||||||
'status': fields.String,
|
|
||||||
'introduction': fields.String,
|
|
||||||
'created_at': TimestampField
|
|
||||||
}
|
|
||||||
|
|
||||||
conversation_infinite_scroll_pagination_fields = {
|
|
||||||
'limit': fields.Integer,
|
|
||||||
'has_more': fields.Boolean,
|
|
||||||
'data': fields.List(fields.Nested(conversation_fields))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationListApi(WebApiResource):
|
class ConversationListApi(WebApiResource):
|
||||||
|
|
||||||
|
@ -73,7 +59,7 @@ class ConversationApi(WebApiResource):
|
||||||
|
|
||||||
class ConversationRenameApi(WebApiResource):
|
class ConversationRenameApi(WebApiResource):
|
||||||
|
|
||||||
@marshal_with(conversation_fields)
|
@marshal_with(simple_conversation_fields)
|
||||||
def post(self, app_model, end_user, c_id):
|
def post(self, app_model, end_user, c_id):
|
||||||
if app_model.mode != 'chat':
|
if app_model.mode != 'chat':
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
|
@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
|
||||||
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
|
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
|
||||||
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
|
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
|
||||||
SEARCH_URL = "https://api.notion.com/v1/search"
|
SEARCH_URL = "https://api.notion.com/v1/search"
|
||||||
|
|
||||||
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
|
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
|
||||||
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
|
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
|
||||||
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
|
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
|
||||||
|
|
|
@ -246,11 +246,28 @@ class KeywordTableIndex(BaseIndex):
|
||||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||||
self._save_dataset_keyword_table(keyword_table)
|
self._save_dataset_keyword_table(keyword_table)
|
||||||
|
|
||||||
|
def multi_create_segment_keywords(self, pre_segment_data_list: list):
|
||||||
|
keyword_table_handler = JiebaKeywordTableHandler()
|
||||||
|
keyword_table = self._get_dataset_keyword_table()
|
||||||
|
for pre_segment_data in pre_segment_data_list:
|
||||||
|
segment = pre_segment_data['segment']
|
||||||
|
if pre_segment_data['keywords']:
|
||||||
|
segment.keywords = pre_segment_data['keywords']
|
||||||
|
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id,
|
||||||
|
pre_segment_data['keywords'])
|
||||||
|
else:
|
||||||
|
keywords = keyword_table_handler.extract_keywords(segment.content,
|
||||||
|
self._config.max_keywords_per_chunk)
|
||||||
|
segment.keywords = list(keywords)
|
||||||
|
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
|
||||||
|
self._save_dataset_keyword_table(keyword_table)
|
||||||
|
|
||||||
def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
|
def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
|
||||||
keyword_table = self._get_dataset_keyword_table()
|
keyword_table = self._get_dataset_keyword_table()
|
||||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||||
self._save_dataset_keyword_table(keyword_table)
|
self._save_dataset_keyword_table(keyword_table)
|
||||||
|
|
||||||
|
|
||||||
class KeywordTableRetriever(BaseRetriever, BaseModel):
|
class KeywordTableRetriever(BaseRetriever, BaseModel):
|
||||||
index: KeywordTableIndex
|
index: KeywordTableIndex
|
||||||
search_kwargs: dict = Field(default_factory=dict)
|
search_kwargs: dict = Field(default_factory=dict)
|
||||||
|
|
0
api/fields/__init__.py
Normal file
0
api/fields/__init__.py
Normal file
138
api/fields/app_fields.py
Normal file
138
api/fields/app_fields.py
Normal file
|
@ -0,0 +1,138 @@
|
||||||
|
from flask_restful import fields
|
||||||
|
|
||||||
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
|
app_detail_kernel_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'name': fields.String,
|
||||||
|
'mode': fields.String,
|
||||||
|
'icon': fields.String,
|
||||||
|
'icon_background': fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
related_app_list = {
|
||||||
|
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
|
||||||
|
'total': fields.Integer,
|
||||||
|
}
|
||||||
|
|
||||||
|
model_config_fields = {
|
||||||
|
'opening_statement': fields.String,
|
||||||
|
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
|
||||||
|
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
|
||||||
|
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
|
||||||
|
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
|
||||||
|
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
|
||||||
|
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
|
||||||
|
'model': fields.Raw(attribute='model_dict'),
|
||||||
|
'user_input_form': fields.Raw(attribute='user_input_form_list'),
|
||||||
|
'dataset_query_variable': fields.String,
|
||||||
|
'pre_prompt': fields.String,
|
||||||
|
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
|
||||||
|
}
|
||||||
|
|
||||||
|
app_detail_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'name': fields.String,
|
||||||
|
'mode': fields.String,
|
||||||
|
'icon': fields.String,
|
||||||
|
'icon_background': fields.String,
|
||||||
|
'enable_site': fields.Boolean,
|
||||||
|
'enable_api': fields.Boolean,
|
||||||
|
'api_rpm': fields.Integer,
|
||||||
|
'api_rph': fields.Integer,
|
||||||
|
'is_demo': fields.Boolean,
|
||||||
|
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
|
||||||
|
'created_at': TimestampField
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt_config_fields = {
|
||||||
|
'prompt_template': fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
model_config_partial_fields = {
|
||||||
|
'model': fields.Raw(attribute='model_dict'),
|
||||||
|
'pre_prompt': fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
app_partial_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'name': fields.String,
|
||||||
|
'mode': fields.String,
|
||||||
|
'icon': fields.String,
|
||||||
|
'icon_background': fields.String,
|
||||||
|
'enable_site': fields.Boolean,
|
||||||
|
'enable_api': fields.Boolean,
|
||||||
|
'is_demo': fields.Boolean,
|
||||||
|
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
|
||||||
|
'created_at': TimestampField
|
||||||
|
}
|
||||||
|
|
||||||
|
app_pagination_fields = {
|
||||||
|
'page': fields.Integer,
|
||||||
|
'limit': fields.Integer(attribute='per_page'),
|
||||||
|
'total': fields.Integer,
|
||||||
|
'has_more': fields.Boolean(attribute='has_next'),
|
||||||
|
'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
|
||||||
|
}
|
||||||
|
|
||||||
|
template_fields = {
|
||||||
|
'name': fields.String,
|
||||||
|
'icon': fields.String,
|
||||||
|
'icon_background': fields.String,
|
||||||
|
'description': fields.String,
|
||||||
|
'mode': fields.String,
|
||||||
|
'model_config': fields.Nested(model_config_fields),
|
||||||
|
}
|
||||||
|
|
||||||
|
template_list_fields = {
|
||||||
|
'data': fields.List(fields.Nested(template_fields)),
|
||||||
|
}
|
||||||
|
|
||||||
|
site_fields = {
|
||||||
|
'access_token': fields.String(attribute='code'),
|
||||||
|
'code': fields.String,
|
||||||
|
'title': fields.String,
|
||||||
|
'icon': fields.String,
|
||||||
|
'icon_background': fields.String,
|
||||||
|
'description': fields.String,
|
||||||
|
'default_language': fields.String,
|
||||||
|
'customize_domain': fields.String,
|
||||||
|
'copyright': fields.String,
|
||||||
|
'privacy_policy': fields.String,
|
||||||
|
'customize_token_strategy': fields.String,
|
||||||
|
'prompt_public': fields.Boolean,
|
||||||
|
'app_base_url': fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
app_detail_fields_with_site = {
|
||||||
|
'id': fields.String,
|
||||||
|
'name': fields.String,
|
||||||
|
'mode': fields.String,
|
||||||
|
'icon': fields.String,
|
||||||
|
'icon_background': fields.String,
|
||||||
|
'enable_site': fields.Boolean,
|
||||||
|
'enable_api': fields.Boolean,
|
||||||
|
'api_rpm': fields.Integer,
|
||||||
|
'api_rph': fields.Integer,
|
||||||
|
'is_demo': fields.Boolean,
|
||||||
|
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
|
||||||
|
'site': fields.Nested(site_fields),
|
||||||
|
'api_base_url': fields.String,
|
||||||
|
'created_at': TimestampField
|
||||||
|
}
|
||||||
|
|
||||||
|
app_site_fields = {
|
||||||
|
'app_id': fields.String,
|
||||||
|
'access_token': fields.String(attribute='code'),
|
||||||
|
'code': fields.String,
|
||||||
|
'title': fields.String,
|
||||||
|
'icon': fields.String,
|
||||||
|
'icon_background': fields.String,
|
||||||
|
'description': fields.String,
|
||||||
|
'default_language': fields.String,
|
||||||
|
'customize_domain': fields.String,
|
||||||
|
'copyright': fields.String,
|
||||||
|
'privacy_policy': fields.String,
|
||||||
|
'customize_token_strategy': fields.String,
|
||||||
|
'prompt_public': fields.Boolean
|
||||||
|
}
|
182
api/fields/conversation_fields.py
Normal file
182
api/fields/conversation_fields.py
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
from flask_restful import fields
|
||||||
|
|
||||||
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
|
|
||||||
|
class MessageTextField(fields.Raw):
|
||||||
|
def format(self, value):
|
||||||
|
return value[0]['text'] if value else ''
|
||||||
|
|
||||||
|
|
||||||
|
account_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'name': fields.String,
|
||||||
|
'email': fields.String
|
||||||
|
}
|
||||||
|
|
||||||
|
feedback_fields = {
|
||||||
|
'rating': fields.String,
|
||||||
|
'content': fields.String,
|
||||||
|
'from_source': fields.String,
|
||||||
|
'from_end_user_id': fields.String,
|
||||||
|
'from_account': fields.Nested(account_fields, allow_null=True),
|
||||||
|
}
|
||||||
|
|
||||||
|
annotation_fields = {
|
||||||
|
'content': fields.String,
|
||||||
|
'account': fields.Nested(account_fields, allow_null=True),
|
||||||
|
'created_at': TimestampField
|
||||||
|
}
|
||||||
|
|
||||||
|
message_detail_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'conversation_id': fields.String,
|
||||||
|
'inputs': fields.Raw,
|
||||||
|
'query': fields.String,
|
||||||
|
'message': fields.Raw,
|
||||||
|
'message_tokens': fields.Integer,
|
||||||
|
'answer': fields.String,
|
||||||
|
'answer_tokens': fields.Integer,
|
||||||
|
'provider_response_latency': fields.Float,
|
||||||
|
'from_source': fields.String,
|
||||||
|
'from_end_user_id': fields.String,
|
||||||
|
'from_account_id': fields.String,
|
||||||
|
'feedbacks': fields.List(fields.Nested(feedback_fields)),
|
||||||
|
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
||||||
|
'created_at': TimestampField
|
||||||
|
}
|
||||||
|
|
||||||
|
feedback_stat_fields = {
|
||||||
|
'like': fields.Integer,
|
||||||
|
'dislike': fields.Integer
|
||||||
|
}
|
||||||
|
|
||||||
|
model_config_fields = {
|
||||||
|
'opening_statement': fields.String,
|
||||||
|
'suggested_questions': fields.Raw,
|
||||||
|
'model': fields.Raw,
|
||||||
|
'user_input_form': fields.Raw,
|
||||||
|
'pre_prompt': fields.String,
|
||||||
|
'agent_mode': fields.Raw,
|
||||||
|
}
|
||||||
|
|
||||||
|
simple_configs_fields = {
|
||||||
|
'prompt_template': fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
simple_model_config_fields = {
|
||||||
|
'model': fields.Raw(attribute='model_dict'),
|
||||||
|
'pre_prompt': fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
simple_message_detail_fields = {
|
||||||
|
'inputs': fields.Raw,
|
||||||
|
'query': fields.String,
|
||||||
|
'message': MessageTextField,
|
||||||
|
'answer': fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'status': fields.String,
|
||||||
|
'from_source': fields.String,
|
||||||
|
'from_end_user_id': fields.String,
|
||||||
|
'from_end_user_session_id': fields.String(),
|
||||||
|
'from_account_id': fields.String,
|
||||||
|
'read_at': TimestampField,
|
||||||
|
'created_at': TimestampField,
|
||||||
|
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
||||||
|
'model_config': fields.Nested(simple_model_config_fields),
|
||||||
|
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||||
|
'admin_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||||
|
'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation_pagination_fields = {
|
||||||
|
'page': fields.Integer,
|
||||||
|
'limit': fields.Integer(attribute='per_page'),
|
||||||
|
'total': fields.Integer,
|
||||||
|
'has_more': fields.Boolean(attribute='has_next'),
|
||||||
|
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation_message_detail_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'status': fields.String,
|
||||||
|
'from_source': fields.String,
|
||||||
|
'from_end_user_id': fields.String,
|
||||||
|
'from_account_id': fields.String,
|
||||||
|
'created_at': TimestampField,
|
||||||
|
'model_config': fields.Nested(model_config_fields),
|
||||||
|
'message': fields.Nested(message_detail_fields, attribute='first_message'),
|
||||||
|
}
|
||||||
|
|
||||||
|
simple_model_config_fields = {
|
||||||
|
'model': fields.Raw(attribute='model_dict'),
|
||||||
|
'pre_prompt': fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation_with_summary_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'status': fields.String,
|
||||||
|
'from_source': fields.String,
|
||||||
|
'from_end_user_id': fields.String,
|
||||||
|
'from_end_user_session_id': fields.String,
|
||||||
|
'from_account_id': fields.String,
|
||||||
|
'summary': fields.String(attribute='summary_or_query'),
|
||||||
|
'read_at': TimestampField,
|
||||||
|
'created_at': TimestampField,
|
||||||
|
'annotated': fields.Boolean,
|
||||||
|
'model_config': fields.Nested(simple_model_config_fields),
|
||||||
|
'message_count': fields.Integer,
|
||||||
|
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||||
|
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation_with_summary_pagination_fields = {
|
||||||
|
'page': fields.Integer,
|
||||||
|
'limit': fields.Integer(attribute='per_page'),
|
||||||
|
'total': fields.Integer,
|
||||||
|
'has_more': fields.Boolean(attribute='has_next'),
|
||||||
|
'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items')
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation_detail_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'status': fields.String,
|
||||||
|
'from_source': fields.String,
|
||||||
|
'from_end_user_id': fields.String,
|
||||||
|
'from_account_id': fields.String,
|
||||||
|
'created_at': TimestampField,
|
||||||
|
'annotated': fields.Boolean,
|
||||||
|
'model_config': fields.Nested(model_config_fields),
|
||||||
|
'message_count': fields.Integer,
|
||||||
|
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||||
|
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
simple_conversation_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'name': fields.String,
|
||||||
|
'inputs': fields.Raw,
|
||||||
|
'status': fields.String,
|
||||||
|
'introduction': fields.String,
|
||||||
|
'created_at': TimestampField
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation_infinite_scroll_pagination_fields = {
|
||||||
|
'limit': fields.Integer,
|
||||||
|
'has_more': fields.Boolean,
|
||||||
|
'data': fields.List(fields.Nested(simple_conversation_fields))
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation_with_model_config_fields = {
|
||||||
|
**simple_conversation_fields,
|
||||||
|
'model_config': fields.Raw,
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation_with_model_config_infinite_scroll_pagination_fields = {
|
||||||
|
'limit': fields.Integer,
|
||||||
|
'has_more': fields.Boolean,
|
||||||
|
'data': fields.List(fields.Nested(conversation_with_model_config_fields))
|
||||||
|
}
|
65
api/fields/data_source_fields.py
Normal file
65
api/fields/data_source_fields.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
from flask_restful import fields
|
||||||
|
|
||||||
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
|
integrate_icon_fields = {
|
||||||
|
'type': fields.String,
|
||||||
|
'url': fields.String,
|
||||||
|
'emoji': fields.String
|
||||||
|
}
|
||||||
|
|
||||||
|
integrate_page_fields = {
|
||||||
|
'page_name': fields.String,
|
||||||
|
'page_id': fields.String,
|
||||||
|
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
|
||||||
|
'is_bound': fields.Boolean,
|
||||||
|
'parent_id': fields.String,
|
||||||
|
'type': fields.String
|
||||||
|
}
|
||||||
|
|
||||||
|
integrate_workspace_fields = {
|
||||||
|
'workspace_name': fields.String,
|
||||||
|
'workspace_id': fields.String,
|
||||||
|
'workspace_icon': fields.String,
|
||||||
|
'pages': fields.List(fields.Nested(integrate_page_fields))
|
||||||
|
}
|
||||||
|
|
||||||
|
integrate_notion_info_list_fields = {
|
||||||
|
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
|
||||||
|
}
|
||||||
|
|
||||||
|
integrate_icon_fields = {
|
||||||
|
'type': fields.String,
|
||||||
|
'url': fields.String,
|
||||||
|
'emoji': fields.String
|
||||||
|
}
|
||||||
|
|
||||||
|
integrate_page_fields = {
|
||||||
|
'page_name': fields.String,
|
||||||
|
'page_id': fields.String,
|
||||||
|
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
|
||||||
|
'parent_id': fields.String,
|
||||||
|
'type': fields.String
|
||||||
|
}
|
||||||
|
|
||||||
|
integrate_workspace_fields = {
|
||||||
|
'workspace_name': fields.String,
|
||||||
|
'workspace_id': fields.String,
|
||||||
|
'workspace_icon': fields.String,
|
||||||
|
'pages': fields.List(fields.Nested(integrate_page_fields)),
|
||||||
|
'total': fields.Integer
|
||||||
|
}
|
||||||
|
|
||||||
|
integrate_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'provider': fields.String,
|
||||||
|
'created_at': TimestampField,
|
||||||
|
'is_bound': fields.Boolean,
|
||||||
|
'disabled': fields.Boolean,
|
||||||
|
'link': fields.String,
|
||||||
|
'source_info': fields.Nested(integrate_workspace_fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
integrate_list_fields = {
|
||||||
|
'data': fields.List(fields.Nested(integrate_fields)),
|
||||||
|
}
|
43
api/fields/dataset_fields.py
Normal file
43
api/fields/dataset_fields.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
from flask_restful import fields
|
||||||
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
|
dataset_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'name': fields.String,
|
||||||
|
'description': fields.String,
|
||||||
|
'permission': fields.String,
|
||||||
|
'data_source_type': fields.String,
|
||||||
|
'indexing_technique': fields.String,
|
||||||
|
'created_by': fields.String,
|
||||||
|
'created_at': TimestampField,
|
||||||
|
}
|
||||||
|
|
||||||
|
dataset_detail_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'name': fields.String,
|
||||||
|
'description': fields.String,
|
||||||
|
'provider': fields.String,
|
||||||
|
'permission': fields.String,
|
||||||
|
'data_source_type': fields.String,
|
||||||
|
'indexing_technique': fields.String,
|
||||||
|
'app_count': fields.Integer,
|
||||||
|
'document_count': fields.Integer,
|
||||||
|
'word_count': fields.Integer,
|
||||||
|
'created_by': fields.String,
|
||||||
|
'created_at': TimestampField,
|
||||||
|
'updated_by': fields.String,
|
||||||
|
'updated_at': TimestampField,
|
||||||
|
'embedding_model': fields.String,
|
||||||
|
'embedding_model_provider': fields.String,
|
||||||
|
'embedding_available': fields.Boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
dataset_query_detail_fields = {
|
||||||
|
"id": fields.String,
|
||||||
|
"content": fields.String,
|
||||||
|
"source": fields.String,
|
||||||
|
"source_app_id": fields.String,
|
||||||
|
"created_by_role": fields.String,
|
||||||
|
"created_by": fields.String,
|
||||||
|
"created_at": TimestampField
|
||||||
|
}
|
76
api/fields/document_fields.py
Normal file
76
api/fields/document_fields.py
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
from flask_restful import fields
|
||||||
|
|
||||||
|
from fields.dataset_fields import dataset_fields
|
||||||
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
|
document_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'position': fields.Integer,
|
||||||
|
'data_source_type': fields.String,
|
||||||
|
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
|
||||||
|
'dataset_process_rule_id': fields.String,
|
||||||
|
'name': fields.String,
|
||||||
|
'created_from': fields.String,
|
||||||
|
'created_by': fields.String,
|
||||||
|
'created_at': TimestampField,
|
||||||
|
'tokens': fields.Integer,
|
||||||
|
'indexing_status': fields.String,
|
||||||
|
'error': fields.String,
|
||||||
|
'enabled': fields.Boolean,
|
||||||
|
'disabled_at': TimestampField,
|
||||||
|
'disabled_by': fields.String,
|
||||||
|
'archived': fields.Boolean,
|
||||||
|
'display_status': fields.String,
|
||||||
|
'word_count': fields.Integer,
|
||||||
|
'hit_count': fields.Integer,
|
||||||
|
'doc_form': fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
document_with_segments_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'position': fields.Integer,
|
||||||
|
'data_source_type': fields.String,
|
||||||
|
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
|
||||||
|
'dataset_process_rule_id': fields.String,
|
||||||
|
'name': fields.String,
|
||||||
|
'created_from': fields.String,
|
||||||
|
'created_by': fields.String,
|
||||||
|
'created_at': TimestampField,
|
||||||
|
'tokens': fields.Integer,
|
||||||
|
'indexing_status': fields.String,
|
||||||
|
'error': fields.String,
|
||||||
|
'enabled': fields.Boolean,
|
||||||
|
'disabled_at': TimestampField,
|
||||||
|
'disabled_by': fields.String,
|
||||||
|
'archived': fields.Boolean,
|
||||||
|
'display_status': fields.String,
|
||||||
|
'word_count': fields.Integer,
|
||||||
|
'hit_count': fields.Integer,
|
||||||
|
'completed_segments': fields.Integer,
|
||||||
|
'total_segments': fields.Integer
|
||||||
|
}
|
||||||
|
|
||||||
|
dataset_and_document_fields = {
|
||||||
|
'dataset': fields.Nested(dataset_fields),
|
||||||
|
'documents': fields.List(fields.Nested(document_fields)),
|
||||||
|
'batch': fields.String
|
||||||
|
}
|
||||||
|
|
||||||
|
document_status_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'indexing_status': fields.String,
|
||||||
|
'processing_started_at': TimestampField,
|
||||||
|
'parsing_completed_at': TimestampField,
|
||||||
|
'cleaning_completed_at': TimestampField,
|
||||||
|
'splitting_completed_at': TimestampField,
|
||||||
|
'completed_at': TimestampField,
|
||||||
|
'paused_at': TimestampField,
|
||||||
|
'error': fields.String,
|
||||||
|
'stopped_at': TimestampField,
|
||||||
|
'completed_segments': fields.Integer,
|
||||||
|
'total_segments': fields.Integer,
|
||||||
|
}
|
||||||
|
|
||||||
|
document_status_fields_list = {
|
||||||
|
'data': fields.List(fields.Nested(document_status_fields))
|
||||||
|
}
|
18
api/fields/file_fields.py
Normal file
18
api/fields/file_fields.py
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
from flask_restful import fields
|
||||||
|
|
||||||
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
|
upload_config_fields = {
|
||||||
|
'file_size_limit': fields.Integer,
|
||||||
|
'batch_count_limit': fields.Integer
|
||||||
|
}
|
||||||
|
|
||||||
|
file_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'name': fields.String,
|
||||||
|
'size': fields.Integer,
|
||||||
|
'extension': fields.String,
|
||||||
|
'mime_type': fields.String,
|
||||||
|
'created_by': fields.String,
|
||||||
|
'created_at': TimestampField,
|
||||||
|
}
|
41
api/fields/hit_testing_fields.py
Normal file
41
api/fields/hit_testing_fields.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
from flask_restful import fields
|
||||||
|
|
||||||
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
|
document_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'data_source_type': fields.String,
|
||||||
|
'name': fields.String,
|
||||||
|
'doc_type': fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
segment_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'position': fields.Integer,
|
||||||
|
'document_id': fields.String,
|
||||||
|
'content': fields.String,
|
||||||
|
'answer': fields.String,
|
||||||
|
'word_count': fields.Integer,
|
||||||
|
'tokens': fields.Integer,
|
||||||
|
'keywords': fields.List(fields.String),
|
||||||
|
'index_node_id': fields.String,
|
||||||
|
'index_node_hash': fields.String,
|
||||||
|
'hit_count': fields.Integer,
|
||||||
|
'enabled': fields.Boolean,
|
||||||
|
'disabled_at': TimestampField,
|
||||||
|
'disabled_by': fields.String,
|
||||||
|
'status': fields.String,
|
||||||
|
'created_by': fields.String,
|
||||||
|
'created_at': TimestampField,
|
||||||
|
'indexing_at': TimestampField,
|
||||||
|
'completed_at': TimestampField,
|
||||||
|
'error': fields.String,
|
||||||
|
'stopped_at': TimestampField,
|
||||||
|
'document': fields.Nested(document_fields),
|
||||||
|
}
|
||||||
|
|
||||||
|
hit_testing_record_fields = {
|
||||||
|
'segment': fields.Nested(segment_fields),
|
||||||
|
'score': fields.Float,
|
||||||
|
'tsne_position': fields.Raw
|
||||||
|
}
|
25
api/fields/installed_app_fields.py
Normal file
25
api/fields/installed_app_fields.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
from flask_restful import fields
|
||||||
|
|
||||||
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
|
app_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'name': fields.String,
|
||||||
|
'mode': fields.String,
|
||||||
|
'icon': fields.String,
|
||||||
|
'icon_background': fields.String
|
||||||
|
}
|
||||||
|
|
||||||
|
installed_app_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'app': fields.Nested(app_fields),
|
||||||
|
'app_owner_tenant_id': fields.String,
|
||||||
|
'is_pinned': fields.Boolean,
|
||||||
|
'last_used_at': TimestampField,
|
||||||
|
'editable': fields.Boolean,
|
||||||
|
'uninstallable': fields.Boolean,
|
||||||
|
}
|
||||||
|
|
||||||
|
installed_app_list_fields = {
|
||||||
|
'installed_apps': fields.List(fields.Nested(installed_app_fields))
|
||||||
|
}
|
43
api/fields/message_fields.py
Normal file
43
api/fields/message_fields.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
from flask_restful import fields
|
||||||
|
|
||||||
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
|
feedback_fields = {
|
||||||
|
'rating': fields.String
|
||||||
|
}
|
||||||
|
|
||||||
|
retriever_resource_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'message_id': fields.String,
|
||||||
|
'position': fields.Integer,
|
||||||
|
'dataset_id': fields.String,
|
||||||
|
'dataset_name': fields.String,
|
||||||
|
'document_id': fields.String,
|
||||||
|
'document_name': fields.String,
|
||||||
|
'data_source_type': fields.String,
|
||||||
|
'segment_id': fields.String,
|
||||||
|
'score': fields.Float,
|
||||||
|
'hit_count': fields.Integer,
|
||||||
|
'word_count': fields.Integer,
|
||||||
|
'segment_position': fields.Integer,
|
||||||
|
'index_node_hash': fields.String,
|
||||||
|
'content': fields.String,
|
||||||
|
'created_at': TimestampField
|
||||||
|
}
|
||||||
|
|
||||||
|
message_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'conversation_id': fields.String,
|
||||||
|
'inputs': fields.Raw,
|
||||||
|
'query': fields.String,
|
||||||
|
'answer': fields.String,
|
||||||
|
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||||
|
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
||||||
|
'created_at': TimestampField
|
||||||
|
}
|
||||||
|
|
||||||
|
message_infinite_scroll_pagination_fields = {
|
||||||
|
'limit': fields.Integer,
|
||||||
|
'has_more': fields.Boolean,
|
||||||
|
'data': fields.List(fields.Nested(message_fields))
|
||||||
|
}
|
32
api/fields/segment_fields.py
Normal file
32
api/fields/segment_fields.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
from flask_restful import fields
|
||||||
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
|
segment_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'position': fields.Integer,
|
||||||
|
'document_id': fields.String,
|
||||||
|
'content': fields.String,
|
||||||
|
'answer': fields.String,
|
||||||
|
'word_count': fields.Integer,
|
||||||
|
'tokens': fields.Integer,
|
||||||
|
'keywords': fields.List(fields.String),
|
||||||
|
'index_node_id': fields.String,
|
||||||
|
'index_node_hash': fields.String,
|
||||||
|
'hit_count': fields.Integer,
|
||||||
|
'enabled': fields.Boolean,
|
||||||
|
'disabled_at': TimestampField,
|
||||||
|
'disabled_by': fields.String,
|
||||||
|
'status': fields.String,
|
||||||
|
'created_by': fields.String,
|
||||||
|
'created_at': TimestampField,
|
||||||
|
'indexing_at': TimestampField,
|
||||||
|
'completed_at': TimestampField,
|
||||||
|
'error': fields.String,
|
||||||
|
'stopped_at': TimestampField
|
||||||
|
}
|
||||||
|
|
||||||
|
segment_list_response = {
|
||||||
|
'data': fields.List(fields.Nested(segment_fields)),
|
||||||
|
'has_more': fields.Boolean,
|
||||||
|
'limit': fields.Integer
|
||||||
|
}
|
|
@ -0,0 +1,36 @@
|
||||||
|
"""add_tenant_id_in_api_token
|
||||||
|
|
||||||
|
Revision ID: 2e9819ca5b28
|
||||||
|
Revises: 6e2cfb077b04
|
||||||
|
Create Date: 2023-09-22 15:41:01.243183
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '2e9819ca5b28'
|
||||||
|
down_revision = 'ab23c11305d4'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True))
|
||||||
|
batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
|
||||||
|
batch_op.drop_column('dataset_id')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True))
|
||||||
|
batch_op.drop_index('api_token_tenant_idx')
|
||||||
|
batch_op.drop_column('tenant_id')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
|
@ -629,12 +629,13 @@ class ApiToken(db.Model):
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint('id', name='api_token_pkey'),
|
db.PrimaryKeyConstraint('id', name='api_token_pkey'),
|
||||||
db.Index('api_token_app_id_type_idx', 'app_id', 'type'),
|
db.Index('api_token_app_id_type_idx', 'app_id', 'type'),
|
||||||
db.Index('api_token_token_idx', 'token', 'type')
|
db.Index('api_token_token_idx', 'token', 'type'),
|
||||||
|
db.Index('api_token_tenant_idx', 'tenant_id', 'type')
|
||||||
)
|
)
|
||||||
|
|
||||||
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||||
app_id = db.Column(UUID, nullable=True)
|
app_id = db.Column(UUID, nullable=True)
|
||||||
dataset_id = db.Column(UUID, nullable=True)
|
tenant_id = db.Column(UUID, nullable=True)
|
||||||
type = db.Column(db.String(16), nullable=False)
|
type = db.Column(db.String(16), nullable=False)
|
||||||
token = db.Column(db.String(255), nullable=False)
|
token = db.Column(db.String(255), nullable=False)
|
||||||
last_used_at = db.Column(db.DateTime, nullable=True)
|
last_used_at = db.Column(db.DateTime, nullable=True)
|
||||||
|
|
|
@ -96,7 +96,7 @@ class DatasetService:
|
||||||
embedding_model = None
|
embedding_model = None
|
||||||
if indexing_technique == 'high_quality':
|
if indexing_technique == 'high_quality':
|
||||||
embedding_model = ModelFactory.get_embedding_model(
|
embedding_model = ModelFactory.get_embedding_model(
|
||||||
tenant_id=current_user.current_tenant_id
|
tenant_id=tenant_id
|
||||||
)
|
)
|
||||||
dataset = Dataset(name=name, indexing_technique=indexing_technique)
|
dataset = Dataset(name=name, indexing_technique=indexing_technique)
|
||||||
# dataset = Dataset(name=name, provider=provider, config=config)
|
# dataset = Dataset(name=name, provider=provider, config=config)
|
||||||
|
@ -477,6 +477,7 @@ class DocumentService:
|
||||||
)
|
)
|
||||||
dataset.collection_binding_id = dataset_collection_binding.id
|
dataset.collection_binding_id = dataset_collection_binding.id
|
||||||
|
|
||||||
|
|
||||||
documents = []
|
documents = []
|
||||||
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
|
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
|
||||||
if 'original_document_id' in document_data and document_data["original_document_id"]:
|
if 'original_document_id' in document_data and document_data["original_document_id"]:
|
||||||
|
@ -626,6 +627,9 @@ class DocumentService:
|
||||||
document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
|
document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
|
||||||
if document.display_status != 'available':
|
if document.display_status != 'available':
|
||||||
raise ValueError("Document is not available")
|
raise ValueError("Document is not available")
|
||||||
|
# update document name
|
||||||
|
if 'name' in document_data and document_data['name']:
|
||||||
|
document.name = document_data['name']
|
||||||
# save process rule
|
# save process rule
|
||||||
if 'process_rule' in document_data and document_data['process_rule']:
|
if 'process_rule' in document_data and document_data['process_rule']:
|
||||||
process_rule = document_data["process_rule"]
|
process_rule = document_data["process_rule"]
|
||||||
|
@ -767,7 +771,7 @@ class DocumentService:
|
||||||
return dataset, documents, batch
|
return dataset, documents, batch
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def document_create_args_validate(cls, args: dict):
|
def document_create_args_validate(cls, args: dict):
|
||||||
if 'original_document_id' not in args or not args['original_document_id']:
|
if 'original_document_id' not in args or not args['original_document_id']:
|
||||||
DocumentService.data_source_args_validate(args)
|
DocumentService.data_source_args_validate(args)
|
||||||
DocumentService.process_rule_args_validate(args)
|
DocumentService.process_rule_args_validate(args)
|
||||||
|
@ -1014,6 +1018,66 @@ class SegmentService:
|
||||||
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
|
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
|
||||||
return segment
|
return segment
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
|
||||||
|
embedding_model = None
|
||||||
|
if dataset.indexing_technique == 'high_quality':
|
||||||
|
embedding_model = ModelFactory.get_embedding_model(
|
||||||
|
tenant_id=dataset.tenant_id,
|
||||||
|
model_provider_name=dataset.embedding_model_provider,
|
||||||
|
model_name=dataset.embedding_model
|
||||||
|
)
|
||||||
|
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||||
|
DocumentSegment.document_id == document.id
|
||||||
|
).scalar()
|
||||||
|
pre_segment_data_list = []
|
||||||
|
segment_data_list = []
|
||||||
|
for segment_item in segments:
|
||||||
|
content = segment_item['content']
|
||||||
|
doc_id = str(uuid.uuid4())
|
||||||
|
segment_hash = helper.generate_text_hash(content)
|
||||||
|
tokens = 0
|
||||||
|
if dataset.indexing_technique == 'high_quality' and embedding_model:
|
||||||
|
# calc embedding use tokens
|
||||||
|
tokens = embedding_model.get_num_tokens(content)
|
||||||
|
segment_document = DocumentSegment(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
dataset_id=document.dataset_id,
|
||||||
|
document_id=document.id,
|
||||||
|
index_node_id=doc_id,
|
||||||
|
index_node_hash=segment_hash,
|
||||||
|
position=max_position + 1 if max_position else 1,
|
||||||
|
content=content,
|
||||||
|
word_count=len(content),
|
||||||
|
tokens=tokens,
|
||||||
|
status='completed',
|
||||||
|
indexing_at=datetime.datetime.utcnow(),
|
||||||
|
completed_at=datetime.datetime.utcnow(),
|
||||||
|
created_by=current_user.id
|
||||||
|
)
|
||||||
|
if document.doc_form == 'qa_model':
|
||||||
|
segment_document.answer = segment_item['answer']
|
||||||
|
db.session.add(segment_document)
|
||||||
|
segment_data_list.append(segment_document)
|
||||||
|
pre_segment_data = {
|
||||||
|
'segment': segment_document,
|
||||||
|
'keywords': segment_item['keywords']
|
||||||
|
}
|
||||||
|
pre_segment_data_list.append(pre_segment_data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# save vector index
|
||||||
|
VectorService.multi_create_segment_vector(pre_segment_data_list, dataset)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception("create segment index failed")
|
||||||
|
for segment_document in segment_data_list:
|
||||||
|
segment_document.enabled = False
|
||||||
|
segment_document.disabled_at = datetime.datetime.utcnow()
|
||||||
|
segment_document.status = 'error'
|
||||||
|
segment_document.error = str(e)
|
||||||
|
db.session.commit()
|
||||||
|
return segment_data_list
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
|
def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
|
||||||
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
|
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
|
'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
|
||||||
'app', 'completion', 'audio'
|
'app', 'completion', 'audio', 'file'
|
||||||
]
|
]
|
||||||
|
|
||||||
from . import *
|
from . import *
|
||||||
|
|
|
@ -3,3 +3,11 @@ from services.errors.base import BaseServiceError
|
||||||
|
|
||||||
class FileNotExistsError(BaseServiceError):
|
class FileNotExistsError(BaseServiceError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FileTooLargeError(BaseServiceError):
|
||||||
|
description = "{message}"
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedFileTypeError(BaseServiceError):
|
||||||
|
pass
|
||||||
|
|
123
api/services/file_service.py
Normal file
123
api/services/file_service.py
Normal file
|
@ -0,0 +1,123 @@
|
||||||
|
import datetime
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from cachetools import TTLCache
|
||||||
|
from flask import request, current_app
|
||||||
|
from flask_login import current_user
|
||||||
|
from werkzeug.datastructures import FileStorage
|
||||||
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from core.data_loader.file_extractor import FileExtractor
|
||||||
|
from extensions.ext_storage import storage
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.model import UploadFile
|
||||||
|
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||||
|
|
||||||
|
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
|
||||||
|
PREVIEW_WORDS_LIMIT = 3000
|
||||||
|
cache = TTLCache(maxsize=None, ttl=30)
|
||||||
|
|
||||||
|
|
||||||
|
class FileService:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def upload_file(file: FileStorage) -> UploadFile:
|
||||||
|
# read file content
|
||||||
|
file_content = file.read()
|
||||||
|
# get file size
|
||||||
|
file_size = len(file_content)
|
||||||
|
|
||||||
|
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
|
||||||
|
if file_size > file_size_limit:
|
||||||
|
message = f'File size exceeded. {file_size} > {file_size_limit}'
|
||||||
|
raise FileTooLargeError(message)
|
||||||
|
|
||||||
|
extension = file.filename.split('.')[-1]
|
||||||
|
if extension.lower() not in ALLOWED_EXTENSIONS:
|
||||||
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
|
# user uuid as file name
|
||||||
|
file_uuid = str(uuid.uuid4())
|
||||||
|
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
|
||||||
|
|
||||||
|
# save file to storage
|
||||||
|
storage.save(file_key, file_content)
|
||||||
|
|
||||||
|
# save file to db
|
||||||
|
config = current_app.config
|
||||||
|
upload_file = UploadFile(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
storage_type=config['STORAGE_TYPE'],
|
||||||
|
key=file_key,
|
||||||
|
name=file.filename,
|
||||||
|
size=file_size,
|
||||||
|
extension=extension,
|
||||||
|
mime_type=file.mimetype,
|
||||||
|
created_by=current_user.id,
|
||||||
|
created_at=datetime.datetime.utcnow(),
|
||||||
|
used=False,
|
||||||
|
hash=hashlib.sha3_256(file_content).hexdigest()
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(upload_file)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return upload_file
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def upload_text(text: str, text_name: str) -> UploadFile:
|
||||||
|
# user uuid as file name
|
||||||
|
file_uuid = str(uuid.uuid4())
|
||||||
|
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt'
|
||||||
|
|
||||||
|
# save file to storage
|
||||||
|
storage.save(file_key, text.encode('utf-8'))
|
||||||
|
|
||||||
|
# save file to db
|
||||||
|
config = current_app.config
|
||||||
|
upload_file = UploadFile(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
storage_type=config['STORAGE_TYPE'],
|
||||||
|
key=file_key,
|
||||||
|
name=text_name + '.txt',
|
||||||
|
size=len(text),
|
||||||
|
extension='txt',
|
||||||
|
mime_type='text/plain',
|
||||||
|
created_by=current_user.id,
|
||||||
|
created_at=datetime.datetime.utcnow(),
|
||||||
|
used=True,
|
||||||
|
used_by=current_user.id,
|
||||||
|
used_at=datetime.datetime.utcnow()
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(upload_file)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return upload_file
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_file_preview(file_id: str) -> str:
|
||||||
|
# get file storage key
|
||||||
|
key = file_id + request.path
|
||||||
|
cached_response = cache.get(key)
|
||||||
|
if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
|
||||||
|
return cached_response['response']
|
||||||
|
|
||||||
|
upload_file = db.session.query(UploadFile) \
|
||||||
|
.filter(UploadFile.id == file_id) \
|
||||||
|
.first()
|
||||||
|
|
||||||
|
if not upload_file:
|
||||||
|
raise NotFound("File not found")
|
||||||
|
|
||||||
|
# extract text from file
|
||||||
|
extension = upload_file.extension
|
||||||
|
if extension.lower() not in ALLOWED_EXTENSIONS:
|
||||||
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
|
text = FileExtractor.load(upload_file, return_text=True)
|
||||||
|
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
|
||||||
|
|
||||||
|
return text
|
|
@ -35,6 +35,32 @@ class VectorService:
|
||||||
else:
|
else:
|
||||||
index.add_texts([document])
|
index.add_texts([document])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset):
|
||||||
|
documents = []
|
||||||
|
for pre_segment_data in pre_segment_data_list:
|
||||||
|
segment = pre_segment_data['segment']
|
||||||
|
document = Document(
|
||||||
|
page_content=segment.content,
|
||||||
|
metadata={
|
||||||
|
"doc_id": segment.index_node_id,
|
||||||
|
"doc_hash": segment.index_node_hash,
|
||||||
|
"document_id": segment.document_id,
|
||||||
|
"dataset_id": segment.dataset_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
documents.append(document)
|
||||||
|
|
||||||
|
# save vector index
|
||||||
|
index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||||
|
if index:
|
||||||
|
index.add_texts(documents, duplicate_check=True)
|
||||||
|
|
||||||
|
# save keyword index
|
||||||
|
keyword_index = IndexBuilder.get_index(dataset, 'economy')
|
||||||
|
if keyword_index:
|
||||||
|
keyword_index.multi_create_segment_keywords(pre_segment_data_list)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset):
|
def update_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset):
|
||||||
# update segment index task
|
# update segment index task
|
||||||
|
|
Loading…
Reference in New Issue
Block a user