mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
calculate tokens
This commit is contained in:
parent
63e34e5227
commit
ea5e8ee7cc
|
@ -349,7 +349,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||||
document = self.get_document(dataset_id, document_id)
|
document = self.get_document(dataset_id, document_id)
|
||||||
|
|
||||||
if document.indexing_status in ['completed', 'error']:
|
if document.indexing_status in ['completed', 'error']:
|
||||||
raise DocumentAlreadyFinishedError()
|
indexing_runner.calculate_tokens(document)
|
||||||
|
|
||||||
data_process_rule = document.dataset_process_rule
|
data_process_rule = document.dataset_process_rule
|
||||||
data_process_rule_dict = data_process_rule.to_dict()
|
data_process_rule_dict = data_process_rule.to_dict()
|
||||||
|
|
|
@ -214,6 +214,61 @@ class IndexingRunner:
|
||||||
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
def calculate_tokens(self, tenant_id: str, tokens: int, dataset_id: str = None,
|
||||||
|
indexing_technique: str = 'economy') -> dict:
|
||||||
|
"""
|
||||||
|
Estimate the indexing for the document.
|
||||||
|
"""
|
||||||
|
embedding_model_instance = None
|
||||||
|
if dataset_id:
|
||||||
|
dataset = Dataset.query.filter_by(
|
||||||
|
id=dataset_id
|
||||||
|
).first()
|
||||||
|
if not dataset:
|
||||||
|
raise ValueError('Dataset not found.')
|
||||||
|
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
|
||||||
|
if dataset.embedding_model_provider:
|
||||||
|
embedding_model_instance = self.model_manager.get_model_instance(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider=dataset.embedding_model_provider,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
model=dataset.embedding_model
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if indexing_technique == 'high_quality':
|
||||||
|
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
)
|
||||||
|
preview_texts = []
|
||||||
|
total_segments = 0
|
||||||
|
total_price = 0
|
||||||
|
currency = 'USD'
|
||||||
|
if embedding_model_instance:
|
||||||
|
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
|
||||||
|
embedding_price_info = embedding_model_type_instance.get_price(
|
||||||
|
model=embedding_model_instance.model,
|
||||||
|
credentials=embedding_model_instance.credentials,
|
||||||
|
price_type=PriceType.INPUT,
|
||||||
|
tokens=tokens
|
||||||
|
)
|
||||||
|
total_price = '{:f}'.format(embedding_price_info.total_amount)
|
||||||
|
currency = embedding_price_info.currency
|
||||||
|
return {
|
||||||
|
"total_segments": total_segments,
|
||||||
|
"tokens": tokens,
|
||||||
|
"total_price": total_price,
|
||||||
|
"currency": currency,
|
||||||
|
"preview": preview_texts
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
|
def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
|
||||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
|
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
|
||||||
indexing_technique: str = 'economy') -> dict:
|
indexing_technique: str = 'economy') -> dict:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user