mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
feat: added dataset recall testing API (#9300)
This commit is contained in:
parent
5c7b1358d4
commit
8501af298f
|
@ -1,88 +1,24 @@
|
|||
import logging
|
||||
from flask_restful import Resource
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from libs.login import login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
class HitTestingApi(Resource):
|
||||
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
args = self.parse_args()
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
)
|
||||
|
||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
logging.exception("Hit testing failed.")
|
||||
raise InternalServerError(str(e))
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
|
||||
|
||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
||||
|
|
85
api/controllers/console/datasets/hit_testing_base.py
Normal file
85
api/controllers/console/datasets/hit_testing_base.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def hit_testing_args_check(args):
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
@staticmethod
|
||||
def parse_args():
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
return parser.parse_args()
|
||||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
)
|
||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
logging.exception("Hit testing failed.")
|
||||
raise InternalServerError(str(e))
|
|
@ -5,7 +5,6 @@ from libs.external_api import ExternalApi
|
|||
bp = Blueprint("service_api", __name__, url_prefix="/v1")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from . import index
|
||||
from .app import app, audio, completion, conversation, file, message, workflow
|
||||
from .dataset import dataset, document, segment
|
||||
from .dataset import dataset, document, hit_testing, segment
|
||||
|
|
17
api/controllers/service_api/dataset/hit_testing.py
Normal file
17
api/controllers/service_api/dataset/hit_testing.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
|
||||
|
||||
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
def post(self, tenant_id, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
args = self.parse_args()
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
|
||||
|
||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
|
@ -1050,6 +1050,151 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
|
||||
---
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/hit_testing'
|
||||
method='POST'
|
||||
title='Dataset hit testing'
|
||||
name='#dataset_hit_testing'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
### Path
|
||||
<Properties>
|
||||
<Property name='dataset_id' type='string' key='dataset_id'>
|
||||
Dataset ID
|
||||
</Property>
|
||||
</Properties>
|
||||
|
||||
### Request Body
|
||||
<Properties>
|
||||
<Property name='query' type='string' key='query'>
|
||||
retrieval keywordc
|
||||
</Property>
|
||||
<Property name='retrieval_model' type='object' key='retrieval_model'>
|
||||
retrieval keyword(Optional, if not filled, it will be recalled according to the default method)
|
||||
- <code>search_method</code> (text) Search method: One of the following four keywords is required
|
||||
- <code>keyword_search</code> Keyword search
|
||||
- <code>semantic_search</code> Semantic search
|
||||
- <code>full_text_search</code> Full-text search
|
||||
- <code>hybrid_search</code> Hybrid search
|
||||
- <code>reranking_enable</code> (bool) Whether to enable reranking, optional, required if the search mode is semantic_search or hybrid_search
|
||||
- <code>reranking_mode</code> (object) Rerank model configuration, optional, required if reranking is enabled
|
||||
- <code>reranking_provider_name</code> (string) Rerank model provider
|
||||
- <code>reranking_model_name</code> (string) Rerank model name
|
||||
- <code>weights</code> (double) Semantic search weight setting in hybrid search mode
|
||||
- <code>top_k</code> (integer) Number of results to return, optional
|
||||
- <code>score_threshold_enabled</code> (bool) Whether to enable score threshold
|
||||
- <code>score_threshold</code> (double) Score threshold
|
||||
</Property>
|
||||
<Property name='external_retrieval_model' type='object' key='external_retrieval_model'>
|
||||
Unused field
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
<Col sticky>
|
||||
<CodeGroup
|
||||
title="Request"
|
||||
tag="POST"
|
||||
label="/datasets/{dataset_id}/hit_testing"
|
||||
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
|
||||
"query": "test",
|
||||
"retrieval_model": {
|
||||
"search_method": "keyword_search",
|
||||
"reranking_enable": false,
|
||||
"reranking_mode": null,
|
||||
"reranking_model": {
|
||||
"reranking_provider_name": "",
|
||||
"reranking_model_name": ""
|
||||
},
|
||||
"weights": null,
|
||||
"top_k": 1,
|
||||
"score_threshold_enabled": false,
|
||||
"score_threshold": null
|
||||
}
|
||||
}'`}
|
||||
>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \
|
||||
--header 'Authorization: Bearer {api_key}' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"query": "test",
|
||||
"retrieval_model": {
|
||||
"search_method": "keyword_search",
|
||||
"reranking_enable": false,
|
||||
"reranking_mode": null,
|
||||
"reranking_model": {
|
||||
"reranking_provider_name": "",
|
||||
"reranking_model_name": ""
|
||||
},
|
||||
"weights": null,
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": false,
|
||||
"score_threshold": null
|
||||
}
|
||||
}'
|
||||
```
|
||||
</CodeGroup>
|
||||
<CodeGroup title="Response">
|
||||
```json {{ title: 'Response' }}
|
||||
{
|
||||
"query": {
|
||||
"content": "test"
|
||||
},
|
||||
"records": [
|
||||
{
|
||||
"segment": {
|
||||
"id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218",
|
||||
"position": 1,
|
||||
"document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
|
||||
"content": "Operation guide",
|
||||
"answer": null,
|
||||
"word_count": 847,
|
||||
"tokens": 280,
|
||||
"keywords": [
|
||||
"install",
|
||||
"java",
|
||||
"base",
|
||||
"scripts",
|
||||
"jdk",
|
||||
"manual",
|
||||
"internal",
|
||||
"opens",
|
||||
"add",
|
||||
"vmoptions"
|
||||
],
|
||||
"index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e",
|
||||
"index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73",
|
||||
"hit_count": 0,
|
||||
"enabled": true,
|
||||
"disabled_at": null,
|
||||
"disabled_by": null,
|
||||
"status": "completed",
|
||||
"created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f",
|
||||
"created_at": 1728734540,
|
||||
"indexing_at": 1728734552,
|
||||
"completed_at": 1728734584,
|
||||
"error": null,
|
||||
"stopped_at": null,
|
||||
"document": {
|
||||
"id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
|
||||
"data_source_type": "upload_file",
|
||||
"name": "readme.txt",
|
||||
"doc_type": null
|
||||
}
|
||||
},
|
||||
"score": 3.730463140527718e-05,
|
||||
"tsne_position": null
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
</CodeGroup>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
---
|
||||
|
||||
<Row>
|
||||
<Col>
|
||||
### Error message
|
||||
|
|
|
@ -1049,6 +1049,152 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
---
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/hit_testing'
|
||||
method='POST'
|
||||
title='知识库召回测试'
|
||||
name='#dataset_hit_testing'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
### Path
|
||||
<Properties>
|
||||
<Property name='dataset_id' type='string' key='dataset_id'>
|
||||
知识库 ID
|
||||
</Property>
|
||||
</Properties>
|
||||
|
||||
### Request Body
|
||||
<Properties>
|
||||
<Property name='query' type='string' key='query'>
|
||||
召回关键词
|
||||
</Property>
|
||||
<Property name='retrieval_model' type='object' key='retrieval_model'>
|
||||
召回参数(选填,如不填,按照默认方式召回)
|
||||
- <code>search_method</code> (text) 检索方法:以下三个关键字之一,必填
|
||||
- <code>keyword_search</code> 关键字检索
|
||||
- <code>semantic_search</code> 语义检索
|
||||
- <code>full_text_search</code> 全文检索
|
||||
- <code>hybrid_search</code> 混合检索
|
||||
- <code>reranking_enable</code> (bool) 是否启用 Reranking,非必填,如果检索模式为semantic_search模式或者hybrid_search则传值
|
||||
- <code>reranking_mode</code> (object) Rerank模型配置,非必填,如果启用了 reranking 则传值
|
||||
- <code>reranking_provider_name</code> (string) Rerank 模型提供商
|
||||
- <code>reranking_model_name</code> (string) Rerank 模型名称
|
||||
- <code>weights</code> (double) 混合检索模式下语意检索的权重设置
|
||||
- <code>top_k</code> (integer) 返回结果数量,非必填
|
||||
- <code>score_threshold_enabled</code> (bool) 是否开启Score阈值
|
||||
- <code>score_threshold</code> (double) Score阈值
|
||||
</Property>
|
||||
<Property name='external_retrieval_model' type='object' key='external_retrieval_model'>
|
||||
未启用字段
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
<Col sticky>
|
||||
<CodeGroup
|
||||
title="Request"
|
||||
tag="POST"
|
||||
label="/datasets/{dataset_id}/hit_testing"
|
||||
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
|
||||
"query": "test",
|
||||
"retrieval_model": {
|
||||
"search_method": "keyword_search",
|
||||
"reranking_enable": false,
|
||||
"reranking_mode": null,
|
||||
"reranking_model": {
|
||||
"reranking_provider_name": "",
|
||||
"reranking_model_name": ""
|
||||
},
|
||||
"weights": null,
|
||||
"top_k": 1,
|
||||
"score_threshold_enabled": false,
|
||||
"score_threshold": null
|
||||
}
|
||||
}'`}
|
||||
>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \
|
||||
--header 'Authorization: Bearer {api_key}' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"query": "test",
|
||||
"retrieval_model": {
|
||||
"search_method": "keyword_search",
|
||||
"reranking_enable": false,
|
||||
"reranking_mode": null,
|
||||
"reranking_model": {
|
||||
"reranking_provider_name": "",
|
||||
"reranking_model_name": ""
|
||||
},
|
||||
"weights": null,
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": false,
|
||||
"score_threshold": null
|
||||
}
|
||||
}'
|
||||
```
|
||||
</CodeGroup>
|
||||
<CodeGroup title="Response">
|
||||
```json {{ title: 'Response' }}
|
||||
{
|
||||
"query": {
|
||||
"content": "test"
|
||||
},
|
||||
"records": [
|
||||
{
|
||||
"segment": {
|
||||
"id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218",
|
||||
"position": 1,
|
||||
"document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
|
||||
"content": "Operation guide",
|
||||
"answer": null,
|
||||
"word_count": 847,
|
||||
"tokens": 280,
|
||||
"keywords": [
|
||||
"install",
|
||||
"java",
|
||||
"base",
|
||||
"scripts",
|
||||
"jdk",
|
||||
"manual",
|
||||
"internal",
|
||||
"opens",
|
||||
"add",
|
||||
"vmoptions"
|
||||
],
|
||||
"index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e",
|
||||
"index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73",
|
||||
"hit_count": 0,
|
||||
"enabled": true,
|
||||
"disabled_at": null,
|
||||
"disabled_by": null,
|
||||
"status": "completed",
|
||||
"created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f",
|
||||
"created_at": 1728734540,
|
||||
"indexing_at": 1728734552,
|
||||
"completed_at": 1728734584,
|
||||
"error": null,
|
||||
"stopped_at": null,
|
||||
"document": {
|
||||
"id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
|
||||
"data_source_type": "upload_file",
|
||||
"name": "readme.txt",
|
||||
"doc_type": null
|
||||
}
|
||||
},
|
||||
"score": 3.730463140527718e-05,
|
||||
"tsne_position": null
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
</CodeGroup>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
|
||||
---
|
||||
|
||||
<Row>
|
||||
|
|
Loading…
Reference in New Issue
Block a user