parent-child support for all events

jyong 2024-10-18 16:19:40 +08:00
parent a17002313b
commit 25b550c0c5
24 changed files with 267 additions and 172 deletions

View File

@ -713,7 +713,8 @@ class DatasetPermissionUserListApi(Resource):
return {
"data": partial_members_list,
}, 200
api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
@ -727,4 +728,4 @@ api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>")
api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")

View File

@ -332,7 +332,6 @@ class DatasetInitApi(Resource):
DocumentService.document_create_args_validate(knowledge_config)
try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
)

View File

@ -1,5 +1,4 @@
import uuid
from datetime import datetime, timezone
import pandas as pd
from flask import request
@ -10,7 +9,13 @@ from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import ChildChunkDeleteIndexError, ChildChunkIndexingError, InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.datasets.error import (
ChildChunkDeleteIndexError,
ChildChunkIndexingError,
InvalidActionError,
NoFileUploadedError,
TooManyFilesError,
)
from controllers.console.setup import setup_required
from controllers.console.wraps import (
account_initialization_required,
@ -22,15 +27,14 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import segment_fields, child_chunk_fields
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required
from models.dataset import ChildChunk, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError, ChildChunkIndexingError as ChildChunkIndexingServiceError
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
class DatasetDocumentSegmentListApi(Resource):
@ -140,7 +144,6 @@ class DatasetDocumentSegmentListApi(Resource):
return {"result": "success"}, 200
class DatasetDocumentSegmentApi(Resource):
@setup_required
@login_required
@ -188,14 +191,13 @@ class DatasetDocumentSegmentApi(Resource):
cache_result = redis_client.get(document_indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later")
try:
try:
SegmentService.update_segments_status(segment_ids, action, dataset, document)
except Exception as e:
raise InvalidActionError(str(e))
return {"result": "success"}, 200
class DatasetDocumentSegmentAddApi(Resource):
@setup_required
@login_required
@ -301,7 +303,9 @@ class DatasetDocumentSegmentUpdateApi(Resource):
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
parser.add_argument("regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json")
parser.add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
@ -406,6 +410,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
class ChildChunkAddApi(Resource):
@setup_required
@login_required
@ -576,6 +581,7 @@ class ChildChunkUpdateApi(Resource):
raise ChildChunkDeleteIndexError(str(e))
return {"result": "success"}, 200
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/<string:action>")
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
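
For reference, a hedged sketch of the JSON body accepted by the segment-update parser shown earlier in this file; only the argument names and defaults come from the diff, and the endpoint path itself is not visible in this hunk.

# Illustrative payload for the segment-update endpoint, matching the reqparse arguments above.
payload = {
    "content": "Updated parent segment text",      # required
    "answer": None,                                  # optional, QA-form documents only
    "keywords": ["parent", "child"],                 # optional
    "regenerate_child_chunks": True,                 # new flag in this commit, defaults to False
}

When the flag is true on a hierarchical-form document, the service layer regenerates child chunks through VectorService.generate_child_chunks, as the service-layer hunks further down show.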

View File

@ -1,14 +1,18 @@
from pydantic import BaseModel
from typing import Optional
from pydantic import BaseModel
class PreviewDetail(BaseModel):
content: str
child_chunks: Optional[list[str]] = None
class QAPreviewDetail(BaseModel):
question: str
answer: str
class IndexingEstimate(BaseModel):
total_segments: int
preview: list[PreviewDetail]
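
A minimal usage sketch of the preview models above, assuming they are used as plain Pydantic models exactly as declared.

# Hypothetical construction of an indexing estimate from one previewed parent chunk.
preview = [
    PreviewDetail(content="Parent chunk text", child_chunks=["child chunk one", "child chunk two"]),
]
estimate = IndexingEstimate(total_segments=len(preview), preview=preview)
# For QA-form datasets the indexing runner also passes qa_preview (see the next file);
# that field is presumably declared on IndexingEstimate outside the lines shown here.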

View File

@ -8,7 +8,7 @@ import time
import uuid
from typing import Optional, cast
from flask import Flask, current_app
from flask import current_app
from flask_login import current_user
from sqlalchemy.orm.exc import ObjectDeletedError
@ -112,7 +112,7 @@ class IndexingRunner:
for document_segment in document_segments:
db.session.delete(document_segment)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
@ -189,7 +189,7 @@ class IndexingRunner:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
@ -286,7 +286,7 @@ class IndexingRunner:
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule.to_dict(),
tenant_id=current_user.current_tenant_id,
doc_language=doc_language,
doc_language=doc_language,
)
total_segments += len(documents)
for document in documents:
@ -304,7 +304,9 @@ class IndexingRunner:
)
document_qa_list = self.format_split_text(response)
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=document_qa_list, preview=preview_texts)
return IndexingEstimate(
total_segments=total_segments * 20, qa_preview=document_qa_list, preview=preview_texts
)
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract(
@ -398,8 +400,11 @@ class IndexingRunner:
@staticmethod
def _get_splitter(
processing_rule_mode: str, max_tokens: int, chunk_overlap: int, separator: str,
embedding_model_instance: Optional[ModelInstance]
processing_rule_mode: str,
max_tokens: int,
chunk_overlap: int,
separator: str,
embedding_model_instance: Optional[ModelInstance],
) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.
@ -409,7 +414,7 @@ class IndexingRunner:
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
if separator:
separator = separator.replace("\\n", "\n")

View File

@ -1,4 +1,5 @@
from typing import Optional
from pydantic import BaseModel
from models.dataset import ChildChunk, DocumentSegment
@ -6,6 +7,7 @@ from models.dataset import ChildChunk, DocumentSegment
class RetrievalSegments(BaseModel):
"""Retrieval segments."""
segment: DocumentSegment
child_chunks: Optional[list[ChildChunk]] = None
score: Optional[float] = None

View File

@ -12,7 +12,8 @@ from core.rag.models.document import Document
from core.rag.rerank.constants.rerank_mode import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DocumentSegment, Document as DatasetDocument
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
@ -250,7 +251,7 @@ class RetrievalService:
ChildChunk.index_node_id == child_index_node_id,
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed"
DocumentSegment.status == "completed",
)
.first()
)
@ -261,7 +262,7 @@ class RetrievalService:
if segment.id not in include_segment_ids:
include_segment_ids.append(segment.id)
map_detail = {
"max_score": document.metadata.get("score", .0),
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk],
}
segment_child_map[segment.id] = map_detail
@ -271,7 +272,9 @@ class RetrievalService:
records.append(record)
else:
segment_child_map[segment.id]["child_chunks"].append(child_chunk)
segment_child_map[segment.id]["max_score"] = max(segment_child_map[segment.id]["max_score"], document.metadata.get("score", .0))
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
continue
else:
@ -301,4 +304,4 @@ class RetrievalService:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id]
record["score"] = segment_child_map[record["segment"].id]["max_score"]
return [RetrievalSegments(**record) for record in records]
return [RetrievalSegments(**record) for record in records]
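
A hedged sketch of the aggregation the hunk above performs: child-chunk hits are grouped under their parent segment, and the segment keeps the maximum child score. The dictionary names mirror the diff; the hit data and the loop are simplified stand-ins.

# Several child-chunk hits collapse into one record per parent segment.
hits = [
    {"segment_id": "seg-1", "score": 0.42},
    {"segment_id": "seg-1", "score": 0.87},
    {"segment_id": "seg-2", "score": 0.55},
]
segment_child_map: dict[str, dict] = {}
for hit in hits:
    detail = segment_child_map.setdefault(hit["segment_id"], {"max_score": 0.0, "child_chunks": []})
    detail["child_chunks"].append(hit)
    detail["max_score"] = max(detail["max_score"], hit["score"])
# "seg-1" ends with max_score 0.87 and two child chunks, matching the max() call in the diff.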

View File

@ -4,4 +4,4 @@ from enum import Enum
class IndexType(str, Enum):
PARAGRAPH_INDEX = "text_model"
QA_INDEX = "qa_model"
PARENT_CHILD_INDEX = "hierarchical_model"
PARENT_CHILD_INDEX = "hierarchical_model"
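
For orientation, a hedged sketch of how the new hierarchical value is consumed elsewhere in this commit: the doc_form string selects an index processor through IndexProcessorFactory. Import paths are taken from other hunks in this commit.

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

# dataset.doc_form stores the enum's string value ("hierarchical_model" for parent-child).
doc_form = IndexType.PARENT_CHILD_INDEX.value
index_processor = IndexProcessorFactory(doc_form).init_index_processor()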

View File

@ -45,7 +45,14 @@ class BaseIndexProcessor(ABC):
) -> list[Document]:
raise NotImplementedError
def _get_splitter(self, processing_rule_mode: str, max_tokens: int, chunk_overlap: int, separator: str, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
def _get_splitter(
self,
processing_rule_mode: str,
max_tokens: int,
chunk_overlap: int,
separator: str,
embedding_model_instance: Optional[ModelInstance],
) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.
"""
@ -54,7 +61,7 @@ class BaseIndexProcessor(ABC):
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
if separator:
separator = separator.replace("\\n", "\n")
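
A standalone illustration of the rule handling above, not the project's code path: the max_tokens bounds check is reproduced, and a separator typed as the two characters backslash and n is normalized to a real newline before splitting. The config value is assumed for the example.

INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH = 4000  # assumed value of dify_config for this example

def normalize_rule(max_tokens: int, separator: str) -> str:
    if max_tokens < 50 or max_tokens > INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:
        raise ValueError(f"Custom segment length should be between 50 and {INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH}.")
    if separator:
        separator = separator.replace("\\n", "\n")  # "\n" typed in a rule becomes a real newline
    return separator

print(repr(normalize_rule(500, "\\n\\n")))  # -> '\n\n'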

View File

@ -19,7 +19,10 @@ from services.entities.knowledge_entities.knowledge_entities import Rule
class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting, is_automatic=(kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical")
extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
)
return text_docs
@ -33,7 +36,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance")
embedding_model_instance=kwargs.get("embedding_model_instance"),
)
all_documents = []
for document in documents:
@ -73,7 +76,6 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
else:
keyword.add_texts(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)

View File

@ -5,22 +5,25 @@ from typing import Optional
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, DocumentSegment
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from extensions.ext_database import db
class ParentChildIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting, is_automatic=(kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical")
extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
)
return text_docs
@ -35,7 +38,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance")
embedding_model_instance=kwargs.get("embedding_model_instance"),
)
all_documents = []
for document in documents:
@ -60,7 +63,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if len(page_content) > 0:
document_node.page_content = page_content
# parse document to child nodes
child_nodes = self._split_child_nodes(document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance"))
child_nodes = self._split_child_nodes(
document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
)
document_node.children = child_nodes
split_documents.append(document_node)
all_documents.extend(split_documents)
@ -68,7 +73,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
page_content = "\n".join([document.page_content for document in documents])
document = Document(page_content=page_content, metadata=documents[0].metadata)
# parse document to child nodes
child_nodes = self._split_child_nodes(document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance"))
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
)
document.children = child_nodes
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
@ -84,7 +91,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
for document in documents:
child_documents = document.children
if child_documents:
formatted_child_documents = [Document(**child_document.model_dump()) for child_document in child_documents]
formatted_child_documents = [
Document(**child_document.model_dump()) for child_document in child_documents
]
vector.create(formatted_child_documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
@ -98,7 +107,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id
ChildChunk.dataset_id == dataset.id,
)
.all()
)
@ -106,10 +115,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
vector.delete_by_ids(child_node_ids)
else:
vector.delete()
delete_child_chunks = kwargs.get("delete_child_chunks") or False
if delete_child_chunks:
db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)).delete()
if delete_child_chunks:
db.session.query(ChildChunk).filter(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete()
db.session.commit()
def retrieve(
@ -139,14 +150,20 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
def _split_child_nodes(self, document_node: Document, rules: Rule, process_rule_mode: str, embedding_model_instance: Optional[ModelInstance]) -> list[ChildDocument]:
def _split_child_nodes(
self,
document_node: Document,
rules: Rule,
process_rule_mode: str,
embedding_model_instance: Optional[ModelInstance],
) -> list[ChildDocument]:
child_splitter = self._get_splitter(
processing_rule_mode=process_rule_mode,
max_tokens=rules.subchunk_segmentation.max_tokens,
chunk_overlap=rules.subchunk_segmentation.chunk_overlap,
separator=rules.subchunk_segmentation.separator,
embedding_model_instance=embedding_model_instance
embedding_model_instance=embedding_model_instance,
)
# parse document to child nodes
child_nodes = []
@ -156,8 +173,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(child_document_node.page_content)
child_document = ChildDocument(
page_content=child_document_node.page_content,
metadata=document_node.metadata
page_content=child_document_node.page_content, metadata=document_node.metadata
)
child_document.metadata["doc_id"] = doc_id
child_document.metadata["doc_hash"] = hash

View File

@ -26,7 +26,10 @@ from services.entities.knowledge_entities.knowledge_entities import Rule
class QAIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting, is_automatic=(kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical")
extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
)
return text_docs
@ -38,8 +41,8 @@ class QAIndexProcessor(BaseIndexProcessor):
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance")
)
embedding_model_instance=kwargs.get("embedding_model_instance"),
)
# Split the text documents into nodes.
all_documents = []

View File

@ -4,6 +4,7 @@ from typing import Any, Optional
from pydantic import BaseModel, Field
class ChildDocument(BaseModel):
"""Class for storing a piece of text and associated metadata."""

View File

@ -19,7 +19,7 @@ from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.dataset import Dataset, Document
from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__)

View File

@ -558,7 +558,7 @@ class DocumentSegment(db.Model):
.filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1)
.first()
)
@property
def child_chunks(self):
child_chunks = db.session.query(ChildChunk).filter(ChildChunk.segment_id == self.id).all()
@ -590,6 +590,7 @@ class DocumentSegment(db.Model):
return text
class ChildChunk(db.Model):
__tablename__ = "child_chunks"
__table_args__ = (
@ -625,11 +626,12 @@ class ChildChunk(db.Model):
@property
def document(self):
return db.session.query(Document).filter(Document.id == self.document_id).first()
@property
def segment(self):
return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()
class AppDatasetJoin(db.Model):
__tablename__ = "app_dataset_joins"
__table_args__ = (
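
A hedged sketch of the new relationships above: a segment exposes its child chunks through the child_chunks property, and each ChildChunk can navigate back to its parent segment. This assumes an application context with an active SQLAlchemy session, as elsewhere in the codebase.

from extensions.ext_database import db
from models.dataset import DocumentSegment

segment = db.session.query(DocumentSegment).first()
if segment is not None:
    for child in segment.child_chunks:          # property defined in the hunk above
        print(child.position, child.content[:40])
        assert child.segment.id == segment.id   # back-reference via the segment property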

View File

@ -39,7 +39,12 @@ from models.dataset import (
)
from models.model import UploadFile
from models.source import DataSourceOauthBinding
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, KnowledgeConfig, RetrievalModel, SegmentUpdateArgs
from services.entities.knowledge_entities.knowledge_entities import (
ChildChunkUpdateArgs,
KnowledgeConfig,
RetrievalModel,
SegmentUpdateArgs,
)
from services.errors.account import InvalidActionError, NoPermissionError
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.dataset import DatasetNameDuplicateError
@ -57,7 +62,6 @@ from tasks.disable_segments_from_index_task import disable_segments_from_index_t
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
@ -711,9 +715,7 @@ class DocumentService:
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
if not dataset.indexing_technique:
if (
knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST
):
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Indexing technique is invalid")
dataset.indexing_technique = knowledge_config.indexing_technique
@ -748,7 +750,7 @@ class DocumentService:
# save process rule
if not dataset_process_rule:
process_rule = knowledge_config.process_rule
if process_rule.mode== "custom" or process_rule.mode== "hierarchical":
if process_rule.mode == "custom" or process_rule.mode == "hierarchical":
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
@ -1139,7 +1141,7 @@ class DocumentService:
retrieval_model=retrieval_model.model_dump_json() if retrieval_model else None,
)
db.session.add(dataset) # type: ignore
db.session.add(dataset) # type: ignore
db.session.flush()
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account)
@ -1209,7 +1211,7 @@ class DocumentService:
if knowledge_config.process_rule.rules.pre_processing_rules is None:
raise ValueError("Process rule pre_processing_rules is required")
unique_pre_processing_rule_dicts = {}
for pre_processing_rule in knowledge_config.process_rule.rules.pre_processing_rules:
if not pre_processing_rule.id:
@ -1230,11 +1232,13 @@ class DocumentService:
if not isinstance(knowledge_config.process_rule.rules.segmentation.separator, str):
raise ValueError("Process rule segmentation separator is invalid")
if not (knowledge_config.process_rule.mode == "hierarchical" and knowledge_config.process_rule.rules.parent_mode == "full-doc"):
if not (
knowledge_config.process_rule.mode == "hierarchical"
and knowledge_config.process_rule.rules.parent_mode == "full-doc"
):
if not knowledge_config.process_rule.rules.segmentation.max_tokens:
raise ValueError("Process rule segmentation max_tokens is required")
raise ValueError("Process rule segmentation max_tokens is required")
if not isinstance(knowledge_config.process_rule.rules.segmentation.max_tokens, int):
raise ValueError("Process rule segmentation max_tokens is invalid")
@ -1535,9 +1539,11 @@ class SegmentService:
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
VectorService.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, True)
.first()
)
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0
@ -1569,7 +1575,7 @@ class SegmentService:
db.session.add(segment)
db.session.commit()
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# get embedding model instance
# get embedding model instance
if dataset.indexing_technique == "high_quality":
# check embedding model setting
model_manager = ModelManager()
@ -1591,10 +1597,12 @@ class SegmentService:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
VectorService.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, True)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form == IndexType.PARAGRAPH_INDEX or document.doc_form == IndexType.QA_INDEX:
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
@ -1626,14 +1634,13 @@ class SegmentService:
@classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
index_node_ids = (
DocumentSegment.query.with_entities(DocumentSegment.index_node_id)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.tenant_id == current_user.current_tenant_id
DocumentSegment.tenant_id == current_user.current_tenant_id,
)
.all()
)
@ -1642,17 +1649,20 @@ class SegmentService:
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).delete()
db.session.commit()
@classmethod
def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document):
if action == "enable":
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == False,
).all()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == False,
)
.all()
)
if not segments:
return
real_deal_segmment_ids = []
@ -1670,12 +1680,16 @@ class SegmentService:
enable_segments_to_index_task.delay(real_deal_segmment_ids, dataset.id, document.id)
elif action == "disable":
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == True,
).all()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == True,
)
.all()
)
if not segments:
return
real_deal_segmment_ids = []
@ -1696,18 +1710,24 @@ class SegmentService:
raise InvalidActionError()
@classmethod
def create_child_chunk(cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset) -> ChildChunk:
def create_child_chunk(
cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset
) -> ChildChunk:
lock_name = "add_child_lock_{}".format(segment.id)
with redis_client.lock(lock_name, timeout=20):
index_node_id = str(uuid.uuid4())
index_node_hash = helper.generate_text_hash(content)
child_chunk_count = db.session.query(ChildChunk).filter(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
).count()
child_chunk= ChildChunk(
child_chunk_count = (
db.session.query(ChildChunk)
.filter(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
)
.count()
)
child_chunk = ChildChunk(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset.id,
document_id=document.id,
@ -1731,14 +1751,24 @@ class SegmentService:
db.session.commit()
return child_chunk
@classmethod
def update_child_chunk(cls, child_chunks_update_args: list[ChildChunkUpdateArgs], segment: DocumentSegment, document: Document, dataset: Dataset) -> list[ChildChunk]:
child_chunks = db.session.query(ChildChunk).filter(
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
).all()
def update_child_chunk(
cls,
child_chunks_update_args: list[ChildChunkUpdateArgs],
segment: DocumentSegment,
document: Document,
dataset: Dataset,
) -> list[ChildChunk]:
child_chunks = (
db.session.query(ChildChunk)
.filter(
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
)
.all()
)
child_chunks_map = {chunk.id: chunk for chunk in child_chunks}
new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], []
@ -1750,6 +1780,9 @@ class SegmentService:
if child_chunk.content != child_chunk_update_args.content:
child_chunk.content = child_chunk_update_args.content
child_chunk.word_count = len(child_chunk.content)
child_chunk.updated_by = current_user.id
child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
child_chunk.type = "customized"
update_child_chunks.append(child_chunk)
else:
new_child_chunks_args.append(child_chunk_update_args)
@ -1761,13 +1794,13 @@ class SegmentService:
if delete_child_chunks:
for child_chunk in delete_child_chunks:
db.session.delete(child_chunk)
db.session.delete(child_chunk)
if new_child_chunks_args:
child_chunk_count = len(child_chunks)
for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1):
index_node_id = str(uuid.uuid4())
index_node_hash = helper.generate_text_hash(args.content)
child_chunk= ChildChunk(
child_chunk = ChildChunk(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset.id,
document_id=document.id,
@ -1791,7 +1824,7 @@ class SegmentService:
db.session.rollback()
raise ChildChunkIndexingError(str(e))
return sorted(new_child_chunks + update_child_chunks, key=lambda x: x.position)
@classmethod
def delete_child_chunk(cls, child_chunk: ChildChunk, dataset: Dataset):
db.session.delete(child_chunk)
@ -1802,15 +1835,20 @@ class SegmentService:
db.session.rollback()
raise ChildChunkDeleteIndexError(str(e))
db.session.commit()
@classmethod
def get_child_chunks(cls, segment_id: str, document_id: str, dataset_id: str):
return db.session.query(ChildChunk).filter(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset_id,
ChildChunk.document_id == document_id,
ChildChunk.segment_id == segment_id,
).all()
return (
db.session.query(ChildChunk)
.filter(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset_id,
ChildChunk.document_id == document_id,
ChildChunk.segment_id == segment_id,
)
.all()
)
class DatasetCollectionBindingService:
@classmethod
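
A hedged sketch of calling the reformatted create_child_chunk entry point above; segment, document and dataset are assumed to be ORM objects already loaded by the caller, and the method itself takes the per-segment Redis lock shown in the hunk before counting and inserting the chunk.

from services.dataset_service import SegmentService

# segment, document and dataset are assumed to be loaded by the calling controller.
child_chunk = SegmentService.create_child_chunk(
    content="A manually added child chunk",
    segment=segment,
    document=document,
    dataset=dataset,
)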

View File

@ -1,7 +1,7 @@
from typing import Literal, Optional, Union
from enum import Enum
from typing import Literal, Optional
from pydantic import BaseModel
from enum import Enum
class ParentMode(str, Enum):

View File

@ -4,5 +4,6 @@ from services.errors.base import BaseServiceError
class ChildChunkIndexingError(BaseServiceError):
description = "{message}"
class ChildChunkDeleteIndexError(BaseServiceError):
description = "{message}"

View File

@ -6,7 +6,7 @@ from core.rag.models.document import Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.account import Account
from models.dataset import ChildChunk, Dataset, DatasetQuery, DocumentSegment, Document as DatasetDocument
from models.dataset import Dataset, DatasetQuery
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,

View File

@ -1,16 +1,14 @@
from typing import Optional
from core.errors.error import LLMBadRequestError
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, ChildChunk, Document as DatasetDocument
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
class VectorService:
@ -20,7 +18,7 @@ class VectorService:
):
documents = []
if doc_form == IndexType.PARENT_CHILD_INDEX:
# get embedding model instance
# get embedding model instance
if dataset.indexing_technique == "high_quality":
# check embedding model setting
model_manager = ModelManager()
@ -43,20 +41,19 @@ class VectorService:
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
.first()
)
for segment in segments:
if doc_form == IndexType.PARENT_CHILD_INDEX:
cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
else:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
@ -95,14 +92,20 @@ class VectorService:
keyword.add_texts([document])
@classmethod
def generate_child_chunks(cls, segment: DocumentSegment, dataset_document: Document, dataset: Dataset,
embedding_model_instance: ModelInstance, processing_rule: DatasetProcessRule,
regenerate: bool = False):
def generate_child_chunks(
cls,
segment: DocumentSegment,
dataset_document: Document,
dataset: Dataset,
embedding_model_instance: ModelInstance,
processing_rule: DatasetProcessRule,
regenerate: bool = False,
):
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
if regenerate:
# delete child chunks
index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True)
# generate child chunks
document = Document(
@ -158,9 +161,15 @@ class VectorService:
# save vector index
vector = Vector(dataset=dataset)
vector.add_texts([child_document], duplicate_check=True)
@classmethod
def update_child_chunk_vector(cls, new_child_chunks: list[ChildChunk], update_child_chunks: list[ChildChunk], delete_child_chunks: list[ChildChunk], dataset: Dataset):
def update_child_chunk_vector(
cls,
new_child_chunks: list[ChildChunk],
update_child_chunks: list[ChildChunk],
delete_child_chunks: list[ChildChunk],
dataset: Dataset,
):
documents = []
delete_node_ids = []
for new_child_chunk in new_child_chunks:
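
A hedged sketch of invoking the reformatted generate_child_chunks classmethod above. The objects are assumed to be loaded as in the service-layer hunks earlier, and the module path of this file is an assumption.

from services.vector_service import VectorService  # assumed module path for the file above

# segment, dataset_document, dataset, embedding_model_instance and processing_rule are
# assumed to be resolved exactly as in SegmentService.update_segment further up.
VectorService.generate_child_chunks(
    segment,
    dataset_document,
    dataset,
    embedding_model_instance,
    processing_rule,
    regenerate=True,  # first cleans the existing child chunks, then re-splits and re-indexes
)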

View File

@ -7,7 +7,6 @@ import click
from celery import shared_task
from sqlalchemy import func
from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db

View File

@ -4,11 +4,9 @@ import time
import click
from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import ChildChunk, Dataset, Document
from models.dataset import Dataset, Document
@shared_task(queue="dataset")
@ -40,9 +38,6 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
end_at = time.perf_counter()
logging.info(
click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green")
)
logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green"))
except Exception:
logging.exception("delete segment from index failed")
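
For context, a hedged sketch of the clean() call this task makes. With delete_child_chunks=True the parent-child processor also deletes the ChildChunk rows for the affected nodes (see its clean() hunk above), whereas the disable task below passes False so only the index entries are removed.

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

# dataset and index_node_ids are assumed to be resolved as in the task above.
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)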

View File

@ -3,12 +3,12 @@ import time
import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment, Document as DatasetDocument
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@shared_task(queue="dataset")
@ -37,24 +37,25 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id
).all()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)
if not segments:
return
try:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
end_at = time.perf_counter()
logging.info(
click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green")
)
logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green"))
except Exception as e:
logging.exception("remove segments from index failed:{}".format(e))
# update segment error msg
@ -64,14 +65,13 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
DocumentSegment.document_id == document_id,
).update(
{
"disabled_at":None,
"disabled_by":None,
"disabled_at": None,
"disabled_by": None,
"enabled": True,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = "segment_{}_indexing".format(segment.id)
indexing_cache_key = "segment_{}_indexing".format(segment.id)
redis_client.delete(indexing_cache_key)

View File

@ -4,14 +4,14 @@ import time
import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment, Dataset, Document as DatasetDocument
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@shared_task(queue="dataset")
@ -39,11 +39,15 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).all()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)
if not segments:
return
@ -58,7 +62,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
"document_id": document_id,
"dataset_id": dataset_id,
},
)
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
@ -81,9 +85,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
index_processor.load(dataset, documents)
end_at = time.perf_counter()
logging.info(
click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green")
)
logging.info(click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green"))
except Exception as e:
logging.exception("enable segments to index failed")
# update segment error msg
@ -102,5 +104,5 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = "segment_{}_indexing".format(segment.id)
indexing_cache_key = "segment_{}_indexing".format(segment.id)
redis_client.delete(indexing_cache_key)
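
To close, a hedged reconstruction of the per-segment document this task assembles before calling index_processor.load(dataset, documents). Field names follow the visible hunks here and in the indexing runner; segment, document_id, dataset_id and dataset_document are assumed to come from the task's loop, and the parent metadata keys not visible in this hunk are inferred from the runner.

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import ChildDocument, Document

document = Document(
    page_content=segment.content,
    metadata={
        "doc_id": segment.index_node_id,
        "doc_hash": segment.index_node_hash,
        "document_id": document_id,
        "dataset_id": dataset_id,
    },
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
    document.children = [
        ChildDocument(
            page_content=child_chunk.content,
            metadata={
                "doc_id": child_chunk.index_node_id,
                "doc_hash": child_chunk.index_node_hash,
                "document_id": document_id,
                "dataset_id": dataset_id,
            },
        )
        for child_chunk in segment.child_chunks  # new property on DocumentSegment
    ]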