add tidb on qdrant type (#9831)

Co-authored-by: Zhaofeng Miao <522856232@qq.com>

parent: fc2297a2ca
commit: 18106a4fc6

@@ -571,6 +571,11 @@ class DataSetConfig(BaseSettings):
         default=False,
     )
 
+    TIDB_SERVERLESS_NUMBER: PositiveInt = Field(
+        description="number of tidb serverless cluster",
+        default=500,
+    )
+
 
 class WorkspaceConfig(BaseSettings):
     """
@@ -27,6 +27,7 @@ from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig
 from configs.middleware.vdb.qdrant_config import QdrantConfig
 from configs.middleware.vdb.relyt_config import RelytConfig
 from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
+from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
 from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
 from configs.middleware.vdb.upstash_config import UpstashConfig
 from configs.middleware.vdb.vikingdb_config import VikingDBConfig
@@ -54,6 +55,11 @@ class VectorStoreConfig(BaseSettings):
         default=None,
     )
 
+    VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field(
+        description="Enable whitelist for vector store.",
+        default=False,
+    )
+
 
 class KeywordStoreConfig(BaseSettings):
     KEYWORD_STORE: str = Field(
@@ -248,5 +254,6 @@ class MiddlewareConfig(
     InternalTestConfig,
     VikingDBConfig,
     UpstashConfig,
+    TidbOnQdrantConfig,
 ):
     pass
api/configs/middleware/vdb/tidb_on_qdrant_config.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from typing import Optional

from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings


class TidbOnQdrantConfig(BaseSettings):
    """
    Tidb on Qdrant configs
    """

    TIDB_ON_QDRANT_URL: Optional[str] = Field(
        description="Tidb on Qdrant url",
        default=None,
    )

    TIDB_ON_QDRANT_API_KEY: Optional[str] = Field(
        description="Tidb on Qdrant api key",
        default=None,
    )

    TIDB_ON_QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
        description="Tidb on Qdrant client timeout in seconds",
        default=20,
    )

    TIDB_ON_QDRANT_GRPC_ENABLED: bool = Field(
        description="whether enable grpc support for Tidb on Qdrant connection",
        default=False,
    )

    TIDB_ON_QDRANT_GRPC_PORT: PositiveInt = Field(
        description="Tidb on Qdrant grpc port",
        default=6334,
    )

    TIDB_PUBLIC_KEY: Optional[str] = Field(
        description="Tidb account public key",
        default=None,
    )

    TIDB_PRIVATE_KEY: Optional[str] = Field(
        description="Tidb account private key",
        default=None,
    )

    TIDB_API_URL: Optional[str] = Field(
        description="Tidb API url",
        default=None,
    )

    TIDB_IAM_API_URL: Optional[str] = Field(
        description="Tidb IAM API url",
        default=None,
    )

    TIDB_REGION: Optional[str] = Field(
        description="Tidb serverless region",
        default="regions/aws-us-east-1",
    )

    TIDB_PROJECT_ID: Optional[str] = Field(
        description="Tidb project id",
        default=None,
    )
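Since TidbOnQdrantConfig subclasses pydantic-settings' BaseSettings, each field above can be supplied through an identically named environment variable. A minimal sketch, not part of the diff, with made-up placeholder values:

import os

# Placeholder values for illustration only; pydantic-settings maps each
# environment variable onto the matching TidbOnQdrantConfig field.
os.environ["TIDB_ON_QDRANT_URL"] = "https://qdrant-gateway.example.com"
os.environ["TIDB_API_URL"] = "https://api.example.com/v1beta1"
os.environ["TIDB_PUBLIC_KEY"] = "<public-key>"
os.environ["TIDB_PRIVATE_KEY"] = "<private-key>"

from configs import dify_config  # the aggregate settings object this PR extends

assert dify_config.TIDB_ON_QDRANT_GRPC_PORT == 6334  # default when not overridden
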
@@ -639,6 +639,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.ORACLE
                 | VectorType.ELASTICSEARCH
                 | VectorType.PGVECTOR
+                | VectorType.TIDB_ON_QDRANT
             ):
                 return {
                     "retrieval_method": [
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py (new file, 17 lines)
@@ -0,0 +1,17 @@
from typing import Optional

from pydantic import BaseModel


class ClusterEntity(BaseModel):
    """
    Model Config Entity.
    """

    name: str
    cluster_id: str
    displayName: str
    region: str
    spendingLimit: Optional[int] = 1000
    version: str
    createdBy: str
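For reference, ClusterEntity mirrors the camelCase keys of a TiDB Cloud cluster payload. A construction sketch with made-up values (not part of the diff):

from core.rag.datasource.vdb.tidb_on_qdrant.tidb_entities import ClusterEntity

cluster = ClusterEntity(
    name="clusters/0000000000",      # placeholder resource name
    cluster_id="0000000000",
    displayName="my-dify-cluster",
    region="regions/aws-us-east-1",
    version="v7.1",                  # placeholder version string
    createdBy="dify",
)
print(cluster.spendingLimit)  # -> 1000 (the field's default)
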
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py (new file, 526 lines)
@@ -0,0 +1,526 @@
import json
import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast

import qdrant_client
import requests
from flask import current_app
from pydantic import BaseModel
from qdrant_client.http import models as rest
from qdrant_client.http.models import (
    FilterSelector,
    HnswConfigDiff,
    PayloadSchemaType,
    TextIndexParams,
    TextIndexType,
    TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth

from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, TidbAuthBinding

if TYPE_CHECKING:
    from qdrant_client import grpc  # noqa
    from qdrant_client.conversions import common_types
    from qdrant_client.http import models as rest

DictFilter = dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]


class TidbOnQdrantConfig(BaseModel):
    endpoint: str
    api_key: Optional[str] = None
    timeout: float = 20
    root_path: Optional[str] = None
    grpc_port: int = 6334
    prefer_grpc: bool = False

    def to_qdrant_params(self):
        if self.endpoint and self.endpoint.startswith("path:"):
            path = self.endpoint.replace("path:", "")
            if not os.path.isabs(path):
                path = os.path.join(self.root_path, path)

            return {"path": path}
        else:
            return {
                "url": self.endpoint,
                "api_key": self.api_key,
                "timeout": self.timeout,
                "verify": False,
                "grpc_port": self.grpc_port,
                "prefer_grpc": self.prefer_grpc,
            }


class TidbConfig(BaseModel):
    api_url: str
    public_key: str
    private_key: str


class TidbOnQdrantVector(BaseVector):
    def __init__(self, collection_name: str, group_id: str, config: TidbOnQdrantConfig, distance_func: str = "Cosine"):
        super().__init__(collection_name)
        self._client_config = config
        self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
        self._distance_func = distance_func.upper()
        self._group_id = group_id

    def get_type(self) -> str:
        return VectorType.TIDB_ON_QDRANT

    def to_index_struct(self) -> dict:
        return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        if texts:
            # get embedding vector size
            vector_size = len(embeddings[0])
            # get collection name
            collection_name = self._collection_name
            # create collection
            self.create_collection(collection_name, vector_size)

            self.add_texts(texts, embeddings, **kwargs)

    def create_collection(self, collection_name: str, vector_size: int):
        lock_name = "vector_indexing_lock_{}".format(collection_name)
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
            if redis_client.get(collection_exist_cache_key):
                return
            collection_name = collection_name or uuid.uuid4().hex
            all_collection_name = []
            collections_response = self._client.get_collections()
            collection_list = collections_response.collections
            for collection in collection_list:
                all_collection_name.append(collection.name)
            if collection_name not in all_collection_name:
                from qdrant_client.http import models as rest

                vectors_config = rest.VectorParams(
                    size=vector_size,
                    distance=rest.Distance[self._distance_func],
                )
                hnsw_config = HnswConfigDiff(
                    m=0,
                    payload_m=16,
                    ef_construct=100,
                    full_scan_threshold=10000,
                    max_indexing_threads=0,
                    on_disk=False,
                )
                self._client.recreate_collection(
                    collection_name=collection_name,
                    vectors_config=vectors_config,
                    hnsw_config=hnsw_config,
                    timeout=int(self._client_config.timeout),
                )

                # create group_id payload index
                self._client.create_payload_index(
                    collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
                )
                # create doc_id payload index
                self._client.create_payload_index(
                    collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
                )
                # create full text index
                text_index_params = TextIndexParams(
                    type=TextIndexType.TEXT,
                    tokenizer=TokenizerType.MULTILINGUAL,
                    min_token_len=2,
                    max_token_len=20,
                    lowercase=True,
                )
                self._client.create_payload_index(
                    collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
                )
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        uuids = self._get_uuids(documents)
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]

        added_ids = []
        for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
            self._client.upsert(collection_name=self._collection_name, points=points)
            added_ids.extend(batch_ids)

        return added_ids

    def _generate_rest_batches(
        self,
        texts: Iterable[str],
        embeddings: list[list[float]],
        metadatas: Optional[list[dict]] = None,
        ids: Optional[Sequence[str]] = None,
        batch_size: int = 64,
        group_id: Optional[str] = None,
    ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
        from qdrant_client.http import models as rest

        texts_iterator = iter(texts)
        embeddings_iterator = iter(embeddings)
        metadatas_iterator = iter(metadatas or [])
        ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
        while batch_texts := list(islice(texts_iterator, batch_size)):
            # Take the corresponding metadata and id for each text in a batch
            batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
            batch_ids = list(islice(ids_iterator, batch_size))

            # Generate the embeddings for all the texts in a batch
            batch_embeddings = list(islice(embeddings_iterator, batch_size))

            points = [
                rest.PointStruct(
                    id=point_id,
                    vector=vector,
                    payload=payload,
                )
                for point_id, vector, payload in zip(
                    batch_ids,
                    batch_embeddings,
                    self._build_payloads(
                        batch_texts,
                        batch_metadatas,
                        Field.CONTENT_KEY.value,
                        Field.METADATA_KEY.value,
                        group_id,
                        Field.GROUP_KEY.value,
                    ),
                )
            ]

            yield batch_ids, points

    @classmethod
    def _build_payloads(
        cls,
        texts: Iterable[str],
        metadatas: Optional[list[dict]],
        content_payload_key: str,
        metadata_payload_key: str,
        group_id: str,
        group_payload_key: str,
    ) -> list[dict]:
        payloads = []
        for i, text in enumerate(texts):
            if text is None:
                raise ValueError(
                    "At least one of the texts is None. Please remove it before "
                    "calling .from_texts or .add_texts on Qdrant instance."
                )
            metadata = metadatas[i] if metadatas is not None else None
            payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})

        return payloads

    def delete_by_metadata_field(self, key: str, value: str):
        from qdrant_client.http import models
        from qdrant_client.http.exceptions import UnexpectedResponse

        try:
            filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key=f"metadata.{key}",
                        match=models.MatchValue(value=value),
                    ),
                ],
            )

            self._reload_if_needed()

            self._client.delete(
                collection_name=self._collection_name,
                points_selector=FilterSelector(filter=filter),
            )
        except UnexpectedResponse as e:
            # Collection does not exist, so return
            if e.status_code == 404:
                return
            # Some other error occurred, so re-raise the exception
            else:
                raise e

    def delete(self):
        from qdrant_client.http.exceptions import UnexpectedResponse

        try:
            self._client.delete_collection(collection_name=self._collection_name)
        except UnexpectedResponse as e:
            # Collection does not exist, so return
            if e.status_code == 404:
                return
            # Some other error occurred, so re-raise the exception
            else:
                raise e

    def delete_by_ids(self, ids: list[str]) -> None:
        from qdrant_client.http import models
        from qdrant_client.http.exceptions import UnexpectedResponse

        for node_id in ids:
            try:
                filter = models.Filter(
                    must=[
                        models.FieldCondition(
                            key="metadata.doc_id",
                            match=models.MatchValue(value=node_id),
                        ),
                    ],
                )
                self._client.delete(
                    collection_name=self._collection_name,
                    points_selector=FilterSelector(filter=filter),
                )
            except UnexpectedResponse as e:
                # Collection does not exist, so return
                if e.status_code == 404:
                    return
                # Some other error occurred, so re-raise the exception
                else:
                    raise e

    def text_exists(self, id: str) -> bool:
        all_collection_name = []
        collections_response = self._client.get_collections()
        collection_list = collections_response.collections
        for collection in collection_list:
            all_collection_name.append(collection.name)
        if self._collection_name not in all_collection_name:
            return False
        response = self._client.retrieve(collection_name=self._collection_name, ids=[id])

        return len(response) > 0

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        from qdrant_client.http import models

        filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=self._group_id),
                ),
            ],
        )
        results = self._client.search(
            collection_name=self._collection_name,
            query_vector=query_vector,
            query_filter=filter,
            limit=kwargs.get("top_k", 4),
            with_payload=True,
            with_vectors=True,
            score_threshold=kwargs.get("score_threshold", 0.0),
        )
        docs = []
        for result in results:
            metadata = result.payload.get(Field.METADATA_KEY.value) or {}
            # duplicate check score threshold
            score_threshold = kwargs.get("score_threshold") or 0.0
            if result.score > score_threshold:
                metadata["score"] = result.score
                doc = Document(
                    page_content=result.payload.get(Field.CONTENT_KEY.value),
                    metadata=metadata,
                )
                docs.append(doc)
        # Sort the documents by score in descending order
        docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        """Return docs most similar by bm25.

        Returns:
            List of documents most similar to the query text and distance for each.
        """
        from qdrant_client.http import models

        scroll_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="page_content",
                    match=models.MatchText(text=query),
                )
            ]
        )
        response = self._client.scroll(
            collection_name=self._collection_name,
            scroll_filter=scroll_filter,
            limit=kwargs.get("top_k", 2),
            with_payload=True,
            with_vectors=True,
        )
        results = response[0]
        documents = []
        for result in results:
            if result:
                document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
                document.metadata["vector"] = result.vector
                documents.append(document)

        return documents

    def _reload_if_needed(self):
        if isinstance(self._client, QdrantLocal):
            self._client = cast(QdrantLocal, self._client)
            self._client._load()

    @classmethod
    def _document_from_scored_point(
        cls,
        scored_point: Any,
        content_payload_key: str,
        metadata_payload_key: str,
    ) -> Document:
        return Document(
            page_content=scored_point.payload.get(content_payload_key),
            metadata=scored_point.payload.get(metadata_payload_key) or {},
        )


class TidbOnQdrantVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
        tidb_auth_binding = (
            db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
        )
        if not tidb_auth_binding:
            idle_tidb_auth_binding = (
                db.session.query(TidbAuthBinding)
                .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
                .limit(1)
                .one_or_none()
            )
            if idle_tidb_auth_binding:
                idle_tidb_auth_binding.active = True
                idle_tidb_auth_binding.tenant_id = dataset.tenant_id
                db.session.commit()
                TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
            else:
                with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
                    tidb_auth_binding = (
                        db.session.query(TidbAuthBinding)
                        .filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
                        .one_or_none()
                    )
                    if tidb_auth_binding:
                        TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

                    else:
                        new_cluster = TidbService.create_tidb_serverless_cluster(
                            dify_config.TIDB_PROJECT_ID,
                            dify_config.TIDB_API_URL,
                            dify_config.TIDB_IAM_API_URL,
                            dify_config.TIDB_PUBLIC_KEY,
                            dify_config.TIDB_PRIVATE_KEY,
                            dify_config.TIDB_REGION,
                        )
                        new_tidb_auth_binding = TidbAuthBinding(
                            cluster_id=new_cluster["cluster_id"],
                            cluster_name=new_cluster["cluster_name"],
                            account=new_cluster["account"],
                            password=new_cluster["password"],
                            tenant_id=dataset.tenant_id,
                            active=True,
                            status="ACTIVE",
                        )
                        db.session.add(new_tidb_auth_binding)
                        db.session.commit()
                        TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"

        else:
            TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_ON_QDRANT, collection_name))

        config = current_app.config

        return TidbOnQdrantVector(
            collection_name=collection_name,
            group_id=dataset.id,
            config=TidbOnQdrantConfig(
                endpoint=dify_config.TIDB_ON_QDRANT_URL,
                api_key=TIDB_ON_QDRANT_API_KEY,
                root_path=config.root_path,
                timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
                grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
                prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,
            ),
        )

    def create_tidb_serverless_cluster(self, tidb_config: TidbConfig, display_name: str, region: str):
        """
        Creates a new TiDB Serverless cluster.
        :param tidb_config: The configuration for the TiDB Cloud API.
        :param display_name: The user-friendly display name of the cluster (required).
        :param region: The region where the cluster will be created (required).

        :return: The response from the API.
        """
        region_object = {
            "name": region,
        }

        labels = {
            "tidb.cloud/project": "1372813089454548012",
        }
        cluster_data = {"displayName": display_name, "region": region_object, "labels": labels}

        response = requests.post(
            f"{tidb_config.api_url}/clusters",
            json=cluster_data,
            auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
        )

        if response.status_code == 200:
            return response.json()
        else:
            response.raise_for_status()

    def change_tidb_serverless_root_password(self, tidb_config: TidbConfig, cluster_id: str, new_password: str):
        """
        Changes the root password of a specific TiDB Serverless cluster.

        :param tidb_config: The configuration for the TiDB Cloud API.
        :param cluster_id: The ID of the cluster for which the password is to be changed (required).
        :param new_password: The new password for the root user (required).
        :return: The response from the API.
        """

        body = {"password": new_password}

        response = requests.put(
            f"{tidb_config.api_url}/clusters/{cluster_id}/password",
            json=body,
            auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
        )

        if response.status_code == 200:
            return response.json()
        else:
            response.raise_for_status()
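Usage sketch, not part of the diff: driving TidbOnQdrantVector directly with a hand-built config. The endpoint, key, and vectors below are placeholders; in the real flow the factory above resolves api_key from a tenant's TidbAuthBinding as "<account>:<password>".

from core.rag.models.document import Document

config = TidbOnQdrantConfig(
    endpoint="https://qdrant-gateway.example.com",  # placeholder
    api_key="account:password",                     # placeholder credential pair
    root_path="/app/api",                           # placeholder
)
vector = TidbOnQdrantVector(collection_name="demo_collection", group_id="dataset-id", config=config)

# create() sizes the collection from the first embedding, then upserts the batch.
vector.create(
    texts=[Document(page_content="hello tidb on qdrant", metadata={"doc_id": "doc-1"})],
    embeddings=[[0.1, 0.2, 0.3]],
)
docs = vector.search_by_vector([0.1, 0.2, 0.3], top_k=4, score_threshold=0.5)
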
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py (new file, 250 lines)
@@ -0,0 +1,250 @@
import time
import uuid

import requests
from requests.auth import HTTPDigestAuth

from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding


class TidbService:
    @staticmethod
    def create_tidb_serverless_cluster(
        project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
    ):
        """
        Creates a new TiDB Serverless cluster.
        :param project_id: The project ID of the TiDB Cloud project (required).
        :param api_url: The URL of the TiDB Cloud API (required).
        :param iam_url: The URL of the TiDB Cloud IAM API (required).
        :param public_key: The public key for the API (required).
        :param private_key: The private key for the API (required).
        :param region: The region where the cluster will be created (required).

        :return: The response from the API.
        """

        region_object = {
            "name": region,
        }

        labels = {
            "tidb.cloud/project": project_id,
        }

        spending_limit = {
            "monthly": 100,
        }
        password = str(uuid.uuid4()).replace("-", "")[:16]
        display_name = str(uuid.uuid4()).replace("-", "")[:16]
        cluster_data = {
            "displayName": display_name,
            "region": region_object,
            "labels": labels,
            "spendingLimit": spending_limit,
            "rootPassword": password,
        }

        response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key))

        if response.status_code == 200:
            response_data = response.json()
            cluster_id = response_data["clusterId"]
            retry_count = 0
            max_retries = 30
            while retry_count < max_retries:
                cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
                if cluster_response["state"] == "ACTIVE":
                    user_prefix = cluster_response["userPrefix"]
                    return {
                        "cluster_id": cluster_id,
                        "cluster_name": display_name,
                        "account": f"{user_prefix}.root",
                        "password": password,
                    }
                time.sleep(30)  # wait 30 seconds before retrying
                retry_count += 1
        else:
            response.raise_for_status()

    @staticmethod
    def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
        """
        Deletes a specific TiDB Serverless cluster.

        :param api_url: The URL of the TiDB Cloud API (required).
        :param public_key: The public key for the API (required).
        :param private_key: The private key for the API (required).
        :param cluster_id: The ID of the cluster to be deleted (required).
        :return: The response from the API.
        """

        response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))

        if response.status_code == 200:
            return response.json()
        else:
            response.raise_for_status()

    @staticmethod
    def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
        """
        Gets the details of a specific TiDB Serverless cluster.

        :param api_url: The URL of the TiDB Cloud API (required).
        :param public_key: The public key for the API (required).
        :param private_key: The private key for the API (required).
        :param cluster_id: The ID of the cluster to be retrieved (required).
        :return: The response from the API.
        """

        response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))

        if response.status_code == 200:
            return response.json()
        else:
            response.raise_for_status()

    @staticmethod
    def change_tidb_serverless_root_password(
        api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str
    ):
        """
        Changes the root password of a specific TiDB Serverless cluster.

        :param api_url: The URL of the TiDB Cloud API (required).
        :param public_key: The public key for the API (required).
        :param private_key: The private key for the API (required).
        :param cluster_id: The ID of the cluster for which the password is to be changed (required).
        :param account: The account for which the password is to be changed (required).
        :param new_password: The new password for the root user (required).
        :return: The response from the API.
        """

        body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}

        response = requests.patch(
            f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
            json=body,
            auth=HTTPDigestAuth(public_key, private_key),
        )

        if response.status_code == 200:
            return response.json()
        else:
            response.raise_for_status()

    @staticmethod
    def batch_update_tidb_serverless_cluster_status(
        tidb_serverless_list: list[TidbAuthBinding],
        project_id: str,
        api_url: str,
        iam_url: str,
        public_key: str,
        private_key: str,
    ) -> list[dict]:
        """
        Update the status of a batch of TiDB Serverless clusters.
        :param tidb_serverless_list: The TidbAuthBinding records whose status should be refreshed (required).
        :param project_id: The project ID of the TiDB Cloud project (required).
        :param api_url: The URL of the TiDB Cloud API (required).
        :param iam_url: The URL of the TiDB Cloud IAM API (required).
        :param public_key: The public key for the API (required).
        :param private_key: The private key for the API (required).

        :return: The response from the API.
        """
        clusters = []
        tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
        cluster_ids = [item.cluster_id for item in tidb_serverless_list]
        params = {"clusterIds": cluster_ids, "view": "FULL"}
        response = requests.get(
            f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
        )

        if response.status_code == 200:
            response_data = response.json()
            cluster_infos = []
            for item in response_data["clusters"]:
                state = item["state"]
                userPrefix = item["userPrefix"]
                if state == "ACTIVE" and len(userPrefix) > 0:
                    cluster_info = tidb_serverless_list_map[item["clusterId"]]
                    cluster_info.status = "ACTIVE"
                    cluster_info.account = f"{userPrefix}.root"
                    db.session.add(cluster_info)
            db.session.commit()
        else:
            response.raise_for_status()

    @staticmethod
    def batch_create_tidb_serverless_cluster(
        batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
    ) -> list[dict]:
        """
        Creates a batch of new TiDB Serverless clusters.
        :param batch_size: The number of clusters to create (required).
        :param project_id: The project ID of the TiDB Cloud project (required).
        :param api_url: The URL of the TiDB Cloud API (required).
        :param iam_url: The URL of the TiDB Cloud IAM API (required).
        :param public_key: The public key for the API (required).
        :param private_key: The private key for the API (required).
        :param region: The region where the clusters will be created (required).

        :return: The response from the API.
        """
        clusters = []
        for _ in range(batch_size):
            region_object = {
                "name": region,
            }

            labels = {
                "tidb.cloud/project": project_id,
            }

            spending_limit = {
                "monthly": 10,
            }
            password = str(uuid.uuid4()).replace("-", "")[:16]
            display_name = str(uuid.uuid4()).replace("-", "")
            cluster_data = {
                "cluster": {
                    "displayName": display_name,
                    "region": region_object,
                    "labels": labels,
                    "spendingLimit": spending_limit,
                    "rootPassword": password,
                }
            }
            cache_key = f"tidb_serverless_cluster_password:{display_name}"
            redis_client.setex(cache_key, 3600, password)
            clusters.append(cluster_data)

        request_body = {"requests": clusters}
        response = requests.post(
            f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key)
        )

        if response.status_code == 200:
            response_data = response.json()
            cluster_infos = []
            for item in response_data["clusters"]:
                cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
                password = redis_client.get(cache_key)
                if not password:
                    continue
                cluster_info = {
                    "cluster_id": item["clusterId"],
                    "cluster_name": item["displayName"],
                    "account": "root",
                    "password": password.decode("utf-8"),
                }
                cluster_infos.append(cluster_info)
            return cluster_infos
        else:
            response.raise_for_status()
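Call-site sketch, not part of the diff: this mirrors how the factory above and the scheduled tasks below invoke TidbService, with every argument drawn from dify_config.

new_cluster = TidbService.create_tidb_serverless_cluster(
    dify_config.TIDB_PROJECT_ID,
    dify_config.TIDB_API_URL,
    dify_config.TIDB_IAM_API_URL,
    dify_config.TIDB_PUBLIC_KEY,
    dify_config.TIDB_PRIVATE_KEY,
    dify_config.TIDB_REGION,
)
# Polls get_tidb_serverless_cluster up to 30 times at 30 s intervals until the
# cluster reports state "ACTIVE", then returns
# {"cluster_id": ..., "cluster_name": ..., "account": "<userPrefix>.root", "password": ...}
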
@@ -9,8 +9,9 @@ from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.embedding.cached_embedding import CacheEmbedding
 from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
+from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from models.dataset import Dataset
+from models.dataset import Dataset, Whitelist
 
 
 class AbstractVectorFactory(ABC):
@@ -35,8 +36,18 @@ class Vector:
     def _init_vector(self) -> BaseVector:
         vector_type = dify_config.VECTOR_STORE
 
         if self._dataset.index_struct_dict:
             vector_type = self._dataset.index_struct_dict["type"]
+        else:
+            if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
+                whitelist = (
+                    db.session.query(Whitelist)
+                    .filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
+                    .one_or_none()
+                )
+                if whitelist:
+                    vector_type = VectorType.TIDB_ON_QDRANT
 
         if not vector_type:
             raise ValueError("Vector store must be specified.")
@@ -115,6 +126,10 @@ class Vector:
                 from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory
 
                 return UpstashVectorFactory
+            case VectorType.TIDB_ON_QDRANT:
+                from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
+
+                return TidbOnQdrantVectorFactory
             case _:
                 raise ValueError(f"Vector store {vector_type} is not supported.")
@@ -19,3 +19,4 @@ class VectorType(str, Enum):
     BAIDU = "baidu"
     VIKINGDB = "vikingdb"
     UPSTASH = "upstash"
+    TIDB_ON_QDRANT = "tidb_on_qdrant"
@@ -1,6 +1,7 @@
 from datetime import timedelta
 
 from celery import Celery, Task
+from celery.schedules import crontab
 from flask import Flask
 
 from configs import dify_config
@@ -55,6 +56,8 @@ def init_app(app: Flask) -> Celery:
     imports = [
         "schedule.clean_embedding_cache_task",
         "schedule.clean_unused_datasets_task",
+        "schedule.create_tidb_serverless_task",
+        "schedule.update_tidb_serverless_status_task",
     ]
     day = dify_config.CELERY_BEAT_SCHEDULER_TIME
     beat_schedule = {
@@ -66,6 +69,14 @@ def init_app(app: Flask) -> Celery:
             "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
             "schedule": timedelta(days=day),
         },
+        "create_tidb_serverless_task": {
+            "task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task",
+            "schedule": crontab(minute="0", hour="*"),
+        },
+        "update_tidb_serverless_status_task": {
+            "task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task",
+            "schedule": crontab(minute="30", hour="*"),
+        },
     }
     celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
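For reference, the two crontab expressions stagger the new beat jobs by half an hour:

from celery.schedules import crontab

crontab(minute="0", hour="*")   # create_tidb_serverless_task: top of every hour
crontab(minute="30", hour="*")  # update_tidb_serverless_status_task: half past every hour
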
Alembic migration 0251a1c768cc, "add-tidb-auth-binding" (new file, 51 lines)
@@ -0,0 +1,51 @@
"""add-tidb-auth-binding

Revision ID: 0251a1c768cc
Revises: 63a83fcf12ba
Create Date: 2024-08-15 09:56:59.012490

"""
import sqlalchemy as sa
from alembic import op

import models as models

# revision identifiers, used by Alembic.
revision = '0251a1c768cc'
down_revision = 'bbadea11becb'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('tidb_auth_bindings',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
    sa.Column('cluster_id', sa.String(length=255), nullable=False),
    sa.Column('cluster_name', sa.String(length=255), nullable=False),
    sa.Column('active', sa.Boolean(), server_default=sa.text('false'), nullable=False),
    sa.Column('status', sa.String(length=255), server_default=sa.text("'CREATING'::character varying"), nullable=False),
    sa.Column('account', sa.String(length=255), nullable=False),
    sa.Column('password', sa.String(length=255), nullable=False),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
    sa.PrimaryKeyConstraint('id', name='tidb_auth_bindings_pkey')
    )
    with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op:
        batch_op.create_index('tidb_auth_bindings_active_idx', ['active'], unique=False)
        batch_op.create_index('tidb_auth_bindings_status_idx', ['status'], unique=False)
        batch_op.create_index('tidb_auth_bindings_created_at_idx', ['created_at'], unique=False)
        batch_op.create_index('tidb_auth_bindings_tenant_idx', ['tenant_id'], unique=False)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op:
        batch_op.drop_index('tidb_auth_bindings_tenant_idx')
        batch_op.drop_index('tidb_auth_bindings_created_at_idx')
        batch_op.drop_index('tidb_auth_bindings_active_idx')
        batch_op.drop_index('tidb_auth_bindings_status_idx')
    op.drop_table('tidb_auth_bindings')
    # ### end Alembic commands ###
Alembic migration 43fa78bc3b7d, "add_white_list" (new file, 42 lines)
@@ -0,0 +1,42 @@
"""add_white_list

Revision ID: 43fa78bc3b7d
Revises: 0251a1c768cc
Create Date: 2024-10-22 09:59:23.713716

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '43fa78bc3b7d'
down_revision = '0251a1c768cc'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('whitelists',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
    sa.Column('category', sa.String(length=255), nullable=False),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
    sa.PrimaryKeyConstraint('id', name='whitelists_pkey')
    )
    with op.batch_alter_table('whitelists', schema=None) as batch_op:
        batch_op.create_index('whitelists_tenant_idx', ['tenant_id'], unique=False)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###

    with op.batch_alter_table('whitelists', schema=None) as batch_op:
        batch_op.drop_index('whitelists_tenant_idx')

    op.drop_table('whitelists')
    # ### end Alembic commands ###
@@ -704,6 +704,38 @@ class DatasetCollectionBinding(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
+class TidbAuthBinding(db.Model):
+    __tablename__ = "tidb_auth_bindings"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
+        db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
+        db.Index("tidb_auth_bindings_active_idx", "active"),
+        db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
+        db.Index("tidb_auth_bindings_status_idx", "status"),
+    )
+    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
+    tenant_id = db.Column(StringUUID, nullable=True)
+    cluster_id = db.Column(db.String(255), nullable=False)
+    cluster_name = db.Column(db.String(255), nullable=False)
+    active = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
+    status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING"))
+    account = db.Column(db.String(255), nullable=False)
+    password = db.Column(db.String(255), nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+
+
+class Whitelist(db.Model):
+    __tablename__ = "whitelists"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
+        db.Index("whitelists_tenant_idx", "tenant_id"),
+    )
+    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
+    tenant_id = db.Column(StringUUID, nullable=True)
+    category = db.Column(db.String(255), nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+
+
 class DatasetPermission(db.Model):
     __tablename__ = "dataset_permissions"
     __table_args__ = (
api/schedule/create_tidb_serverless_task.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import time

import click

import app
from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding


@app.celery.task(queue="dataset")
def create_tidb_serverless_task():
    click.echo(click.style("Start create tidb serverless task.", fg="green"))
    tidb_serverless_number = dify_config.TIDB_SERVERLESS_NUMBER
    start_at = time.perf_counter()
    while True:
        try:
            # check the number of idle tidb serverless
            idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count()
            if idle_tidb_serverless_number >= tidb_serverless_number:
                break
            # create tidb serverless
            iterations_per_thread = 20
            create_clusters(iterations_per_thread)

        except Exception as e:
            click.echo(click.style(f"Error: {e}", fg="red"))
            break

    end_at = time.perf_counter()
    click.echo(click.style("Create tidb serverless task success latency: {}".format(end_at - start_at), fg="green"))


def create_clusters(batch_size):
    try:
        new_clusters = TidbService.batch_create_tidb_serverless_cluster(
            batch_size,
            dify_config.TIDB_PROJECT_ID,
            dify_config.TIDB_API_URL,
            dify_config.TIDB_IAM_API_URL,
            dify_config.TIDB_PUBLIC_KEY,
            dify_config.TIDB_PRIVATE_KEY,
            dify_config.TIDB_REGION,
        )
        for new_cluster in new_clusters:
            tidb_auth_binding = TidbAuthBinding(
                cluster_id=new_cluster["cluster_id"],
                cluster_name=new_cluster["cluster_name"],
                account=new_cluster["account"],
                password=new_cluster["password"],
            )
            db.session.add(tidb_auth_binding)
        db.session.commit()
    except Exception as e:
        click.echo(click.style(f"Error: {e}", fg="red"))
api/schedule/update_tidb_serverless_status_task.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import time

import click

import app
from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from models.dataset import TidbAuthBinding


@app.celery.task(queue="dataset")
def update_tidb_serverless_status_task():
    click.echo(click.style("Update tidb serverless status task.", fg="green"))
    start_at = time.perf_counter()
    while True:
        try:
            # check the number of idle tidb serverless
            tidb_serverless_list = TidbAuthBinding.query.filter(
                TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING"
            ).all()
            if len(tidb_serverless_list) == 0:
                break
            # update tidb serverless status
            iterations_per_thread = 20
            update_clusters(tidb_serverless_list)

        except Exception as e:
            click.echo(click.style(f"Error: {e}", fg="red"))
            break

    end_at = time.perf_counter()
    click.echo(
        click.style("Update tidb serverless status task success latency: {}".format(end_at - start_at), fg="green")
    )


def update_clusters(tidb_serverless_list: list[TidbAuthBinding]):
    try:
        # batch 20
        for i in range(0, len(tidb_serverless_list), 20):
            items = tidb_serverless_list[i : i + 20]
            TidbService.batch_update_tidb_serverless_cluster_status(
                items,
                dify_config.TIDB_PROJECT_ID,
                dify_config.TIDB_API_URL,
                dify_config.TIDB_IAM_API_URL,
                dify_config.TIDB_PUBLIC_KEY,
                dify_config.TIDB_PRIVATE_KEY,
            )
    except Exception as e:
        click.echo(click.style(f"Error: {e}", fg="red"))
api/services/auth/jina.py (new file, 44 lines)
@@ -0,0 +1,44 @@
import json

import requests

from services.auth.api_key_auth_base import ApiKeyAuthBase


class JinaAuth(ApiKeyAuthBase):
    def __init__(self, credentials: dict):
        super().__init__(credentials)
        auth_type = credentials.get("auth_type")
        if auth_type != "bearer":
            raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
        self.api_key = credentials.get("config").get("api_key", None)

        if not self.api_key:
            raise ValueError("No API key provided")

    def validate_credentials(self):
        headers = self._prepare_headers()
        options = {
            "url": "https://example.com",
        }
        response = self._post_request("https://r.jina.ai", options, headers)
        if response.status_code == 200:
            return True
        else:
            self._handle_error(response)

    def _prepare_headers(self):
        return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

    def _post_request(self, url, data, headers):
        return requests.post(url, headers=headers, json=data)

    def _handle_error(self, response):
        if response.status_code in {402, 409, 500}:
            error_message = response.json().get("error", "Unknown error occurred")
            raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
        else:
            if response.text:
                error_message = json.loads(response.text).get("error", "Unknown error occurred")
                raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
            raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
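Usage sketch for the new Jina Reader credential check, not part of the diff (the API key is a placeholder):

from services.auth.jina import JinaAuth

auth = JinaAuth({"auth_type": "bearer", "config": {"api_key": "jina_xxxxxxxx"}})
auth.validate_credentials()  # True on HTTP 200; raises with the API's error message otherwise
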
CheckboxWithLabel component (new file, 40 lines)
@@ -0,0 +1,40 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import cn from '@/utils/classnames'
import Checkbox from '@/app/components/base/checkbox'
import Tooltip from '@/app/components/base/tooltip'

type Props = {
  className?: string
  isChecked: boolean
  onChange: (isChecked: boolean) => void
  label: string
  labelClassName?: string
  tooltip?: string
}

const CheckboxWithLabel: FC<Props> = ({
  className = '',
  isChecked,
  onChange,
  label,
  labelClassName,
  tooltip,
}) => {
  return (
    <label className={cn(className, 'flex items-center h-7 space-x-2')}>
      <Checkbox checked={isChecked} onCheck={() => onChange(!isChecked)} />
      <div className={cn(labelClassName, 'text-sm font-normal text-gray-800')}>{label}</div>
      {tooltip && (
        <Tooltip
          popupContent={
            <div className='w-[200px]'>{tooltip}</div>
          }
          triggerClassName='ml-0.5 w-4 h-4'
        />
      )}
    </label>
  )
}
export default React.memo(CheckboxWithLabel)
ErrorMessage component (new file, 30 lines)
@@ -0,0 +1,30 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import cn from '@/utils/classnames'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'

type Props = {
  className?: string
  title: string
  errorMsg?: string
}

const ErrorMessage: FC<Props> = ({
  className,
  title,
  errorMsg,
}) => {
  return (
    <div className={cn(className, 'py-2 px-4 border-t border-gray-200 bg-[#FFFAEB]')}>
      <div className='flex items-center h-5'>
        <AlertTriangle className='mr-2 w-4 h-4 text-[#F79009]' />
        <div className='text-sm font-medium text-[#DC6803]'>{title}</div>
      </div>
      {errorMsg && (
        <div className='mt-1 pl-6 leading-[18px] text-xs font-normal text-gray-700'>{errorMsg}</div>
      )}
    </div>
  )
}
export default React.memo(ErrorMessage)
@ -0,0 +1,54 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import Input from './input'
import cn from '@/utils/classnames'
import Tooltip from '@/app/components/base/tooltip'

type Props = {
  className?: string
  label: string
  labelClassName?: string
  value: string | number
  onChange: (value: string | number) => void
  isRequired?: boolean
  placeholder?: string
  isNumber?: boolean
  tooltip?: string
}

const Field: FC<Props> = ({
  className,
  label,
  labelClassName,
  value,
  onChange,
  isRequired = false,
  placeholder = '',
  isNumber = false,
  tooltip,
}) => {
  return (
    <div className={cn(className)}>
      <div className='flex py-[7px]'>
        <div className={cn(labelClassName, 'flex items-center h-[18px] text-[13px] font-medium text-gray-900')}>{label}</div>
        {isRequired && <span className='ml-0.5 text-xs font-semibold text-[#D92D20]'>*</span>}
        {tooltip && (
          <Tooltip
            popupContent={
              <div className='w-[200px]'>{tooltip}</div>
            }
            triggerClassName='ml-0.5 w-4 h-4'
          />
        )}
      </div>
      <Input
        value={value}
        onChange={onChange}
        placeholder={placeholder}
        isNumber={isNumber}
      />
    </div>
  )
}

export default React.memo(Field)
@ -0,0 +1,58 @@
'use client'
import type { FC } from 'react'
import React, { useCallback } from 'react'

type Props = {
  value: string | number
  onChange: (value: string | number) => void
  placeholder?: string
  isNumber?: boolean
}

const MIN_VALUE = 0

const Input: FC<Props> = ({
  value,
  onChange,
  placeholder = '',
  isNumber = false,
}) => {
  const handleChange = useCallback((e: React.ChangeEvent<HTMLInputElement>) => {
    const value = e.target.value
    if (isNumber) {
      let numberValue = parseInt(value, 10) // integer only
      if (isNaN(numberValue)) {
        onChange('')
        return
      }
      if (numberValue < MIN_VALUE)
        numberValue = MIN_VALUE

      onChange(numberValue)
      return
    }
    onChange(value)
  }, [isNumber, onChange])

  const otherOption = (() => {
    if (isNumber) {
      return {
        min: MIN_VALUE,
      }
    }
    return {}
  })()
  return (
    <input
      type={isNumber ? 'number' : 'text'}
      {...otherOption}
      value={value}
      onChange={handleChange}
      className='flex h-9 w-full py-1 px-2 rounded-lg text-xs leading-normal bg-gray-100 caret-primary-600 hover:bg-gray-100 focus:ring-1 focus:ring-inset focus:ring-gray-200 focus-visible:outline-none focus:bg-gray-50 placeholder:text-gray-400'
      placeholder={placeholder}
    />
  )
}

export default React.memo(Input)
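A short sketch of the number mode above, assuming the file is saved as input.tsx next to the caller; the state and placeholder are illustrative. In number mode the handler parses integers only, emits '' when parsing fails, and clamps values below MIN_VALUE (0) back up to 0.

'use client'
import React, { useState } from 'react'
import Input from './input' // assumed path

const PageLimitExample = () => {
  const [limit, setLimit] = useState<string | number>(10)
  return (
    <Input
      isNumber // integer parsing; NaN clears the field, values < 0 clamp to 0
      value={limit}
      onChange={setLimit}
      placeholder='Page limit' // illustrative
    />
  )
}

export default PageLimitExample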
@ -0,0 +1,55 @@
'use client'
import { useBoolean } from 'ahooks'
import type { FC } from 'react'
import React, { useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import cn from '@/utils/classnames'
import { Settings04 } from '@/app/components/base/icons/src/vender/line/general'
import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows'

const I18N_PREFIX = 'datasetCreation.stepOne.website'

type Props = {
  className?: string
  children: React.ReactNode
  controlFoldOptions?: number
}

const OptionsWrap: FC<Props> = ({
  className = '',
  children,
  controlFoldOptions,
}) => {
  const { t } = useTranslation()

  const [fold, {
    toggle: foldToggle,
    setTrue: foldHide,
  }] = useBoolean(false)

  useEffect(() => {
    if (controlFoldOptions)
      foldHide()
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [controlFoldOptions])
  return (
    <div className={cn(className, !fold ? 'mb-0' : 'mb-3')}>
      <div
        className='flex justify-between items-center h-[26px] py-1 cursor-pointer select-none'
        onClick={foldToggle}
      >
        <div className='flex items-center text-gray-700'>
          <Settings04 className='mr-1 w-4 h-4' />
          <div className='text-[13px] font-semibold text-gray-800 uppercase'>{t(`${I18N_PREFIX}.options`)}</div>
        </div>
        <ChevronRight className={cn(!fold && 'rotate-90', 'w-4 h-4 text-gray-500')} />
      </div>
      {!fold && (
        <div className='mb-4'>
          {children}
        </div>
      )}
    </div>
  )
}

export default React.memo(OptionsWrap)
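A sketch of how a parent might drive the fold behavior above; the counter pattern and import path are assumptions. Any change to a truthy controlFoldOptions value re-runs the effect and folds the panel, so bumping a counter is one way to collapse the options, e.g. when a crawl starts.

'use client'
import React, { useState } from 'react'
import OptionsWrap from './base/options-wrap' // assumed path

const OptionsExample = () => {
  const [foldSignal, setFoldSignal] = useState(0)
  return (
    <div>
      {/* Incrementing the signal folds the panel via the useEffect above. */}
      <button onClick={() => setFoldSignal(n => n + 1)}>Start crawl</button>
      <OptionsWrap controlFoldOptions={foldSignal}>
        <div>Crawl options go here</div>
      </OptionsWrap>
    </div>
  )
}

export default OptionsExample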
@ -0,0 +1,48 @@
'use client'
import type { FC } from 'react'
import React, { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Input from './input'
import Button from '@/app/components/base/button'

const I18N_PREFIX = 'datasetCreation.stepOne.website'

type Props = {
  isRunning: boolean
  onRun: (url: string) => void
}

const UrlInput: FC<Props> = ({
  isRunning,
  onRun,
}) => {
  const { t } = useTranslation()
  const [url, setUrl] = useState('')
  const handleUrlChange = useCallback((url: string | number) => {
    setUrl(url as string)
  }, [])
  const handleOnRun = useCallback(() => {
    if (isRunning)
      return
    onRun(url)
  }, [isRunning, onRun, url])

  return (
    <div className='flex items-center justify-between'>
      <Input
        value={url}
        onChange={handleUrlChange}
        placeholder='https://docs.dify.ai'
      />
      <Button
        variant='primary'
        onClick={handleOnRun}
        className='ml-2'
        loading={isRunning}
      >
        {!isRunning ? t(`${I18N_PREFIX}.run`) : ''}
      </Button>
    </div>
  )
}

export default React.memo(UrlInput)
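A wiring sketch for the component above; the import path and the crawl call are placeholders. onRun only fires when no run is in flight, and the button shows a spinner while isRunning is true.

'use client'
import React, { useCallback, useState } from 'react'
import UrlInput from './base/url-input' // assumed path

const CrawlStarter = () => {
  const [isRunning, setIsRunning] = useState(false)
  const handleRun = useCallback(async (url: string) => {
    setIsRunning(true)
    try {
      // Placeholder for the real crawl request.
      console.log('start crawl for', url)
    }
    finally {
      setIsRunning(false)
    }
  }, [])
  return <UrlInput isRunning={isRunning} onRun={handleRun} />
}

export default CrawlStarter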
@ -0,0 +1,40 @@
'use client'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import cn from '@/utils/classnames'
import type { CrawlResultItem as CrawlResultItemType } from '@/models/datasets'
import Checkbox from '@/app/components/base/checkbox'

type Props = {
  payload: CrawlResultItemType
  isChecked: boolean
  isPreview: boolean
  onCheckChange: (checked: boolean) => void
  onPreview: () => void
}

const CrawledResultItem: FC<Props> = ({
  isPreview,
  payload,
  isChecked,
  onCheckChange,
  onPreview,
}) => {
  const { t } = useTranslation()

  const handleCheckChange = useCallback(() => {
    onCheckChange(!isChecked)
  }, [isChecked, onCheckChange])
  return (
    <div className={cn(isPreview ? 'border-[#D1E0FF] bg-primary-50 shadow-xs' : 'group hover:bg-gray-100', 'rounded-md px-2 py-[5px] cursor-pointer border border-transparent')}>
      <div className='flex items-center h-5'>
        <Checkbox className='group-hover:border-2 group-hover:border-primary-600 mr-2 shrink-0' checked={isChecked} onCheck={handleCheckChange} />
        <div className='grow w-0 truncate text-sm font-medium text-gray-700' title={payload.title}>{payload.title}</div>
        <div onClick={onPreview} className='hidden group-hover:flex items-center h-6 px-2 text-xs rounded-md font-medium text-gray-500 uppercase hover:bg-gray-50'>{t('datasetCreation.stepOne.website.preview')}</div>
      </div>
      <div className='mt-0.5 truncate pl-6 leading-[18px] text-xs font-normal text-gray-500' title={payload.source_url}>{payload.source_url}</div>
    </div>
  )
}

export default React.memo(CrawledResultItem)
@ -0,0 +1,87 @@
'use client'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import CheckboxWithLabel from './base/checkbox-with-label'
import CrawledResultItem from './crawled-result-item'
import cn from '@/utils/classnames'
import type { CrawlResultItem } from '@/models/datasets'

const I18N_PREFIX = 'datasetCreation.stepOne.website'

type Props = {
  className?: string
  list: CrawlResultItem[]
  checkedList: CrawlResultItem[]
  onSelectedChange: (selected: CrawlResultItem[]) => void
  onPreview: (payload: CrawlResultItem) => void
  usedTime: number
}

const CrawledResult: FC<Props> = ({
  className = '',
  list,
  checkedList,
  onSelectedChange,
  onPreview,
  usedTime,
}) => {
  const { t } = useTranslation()

  const isCheckAll = checkedList.length === list.length

  const handleCheckedAll = useCallback(() => {
    if (!isCheckAll)
      onSelectedChange(list)
    else
      onSelectedChange([])
  }, [isCheckAll, list, onSelectedChange])

  const handleItemCheckChange = useCallback((item: CrawlResultItem) => {
    return (checked: boolean) => {
      if (checked)
        onSelectedChange([...checkedList, item])
      else
        onSelectedChange(checkedList.filter(checkedItem => checkedItem.source_url !== item.source_url))
    }
  }, [checkedList, onSelectedChange])

  const [previewIndex, setPreviewIndex] = React.useState<number>(-1)
  const handlePreview = useCallback((index: number) => {
    return () => {
      setPreviewIndex(index)
      onPreview(list[index])
    }
  }, [list, onPreview])

  return (
    <div className={cn(className, 'border-t border-gray-200')}>
      <div className='flex items-center justify-between h-[34px] px-4 bg-gray-50 shadow-xs border-b-[0.5px] border-black/8 text-xs font-normal text-gray-700'>
        <CheckboxWithLabel
          isChecked={isCheckAll}
          onChange={handleCheckedAll}
          label={isCheckAll ? t(`${I18N_PREFIX}.resetAll`) : t(`${I18N_PREFIX}.selectAll`)}
          labelClassName='!font-medium'
        />
        <div>{t(`${I18N_PREFIX}.scrapTimeInfo`, {
          total: list.length,
          time: usedTime.toFixed(1),
        })}</div>
      </div>
      <div className='p-2'>
        {list.map((item, index) => (
          <CrawledResultItem
            key={item.source_url}
            isPreview={index === previewIndex}
            onPreview={handlePreview(index)}
            payload={item}
            isChecked={checkedList.some(checkedItem => checkedItem.source_url === item.source_url)}
            onCheckChange={handleItemCheckChange(item)}
          />
        ))}
      </div>
    </div>
  )
}

export default React.memo(CrawledResult)
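A selection-wiring sketch for the list component above, reusing the mock data defined at the end of this diff; both relative import paths are assumptions. The parent owns checkedList, and the component reports every change (check-all, reset-all, or per-item toggles keyed by source_url) through onSelectedChange.

'use client'
import React, { useState } from 'react'
import type { CrawlResultItem } from '@/models/datasets'
import CrawledResult from './crawled-result' // assumed path
import mockCrawlResult from './mock-crawl-result' // assumed path to the mock list below

const CrawledResultExample = () => {
  const [checked, setChecked] = useState<CrawlResultItem[]>([])
  return (
    <CrawledResult
      list={mockCrawlResult}
      checkedList={checked}
      onSelectedChange={setChecked}
      onPreview={item => console.log('preview', item.source_url)}
      usedTime={1.6} // seconds, rendered with one decimal place
    />
  )
}

export default CrawledResultExample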
@ -0,0 +1,37 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import cn from '@/utils/classnames'
import { RowStruct } from '@/app/components/base/icons/src/public/other'

type Props = {
  className?: string
  crawledNum: number
  totalNum: number
}

const Crawling: FC<Props> = ({
  className = '',
  crawledNum,
  totalNum,
}) => {
  const { t } = useTranslation()

  return (
    <div className={cn(className, 'border-t border-gray-200')}>
      <div className='flex items-center h-[34px] px-4 bg-gray-50 shadow-xs border-b-[0.5px] border-black/8 text-xs font-normal text-gray-700'>
        {t('datasetCreation.stepOne.website.totalPageScraped')} {crawledNum}/{totalNum}
      </div>

      <div className='p-2'>
        {['', '', '', ''].map((item, index) => (
          <div className='py-[5px]' key={index}>
            <RowStruct />
          </div>
        ))}
      </div>
    </div>
  )
}

export default React.memo(Crawling)
@ -0,0 +1,24 @@
import type { CrawlResultItem } from '@/models/datasets'

const result: CrawlResultItem[] = [
  {
    title: 'Start the frontend Docker container separately',
    markdown: 'Markdown 1',
    description: 'Description 1',
    source_url: 'https://example.com/1',
  },
  {
    title: 'Advanced Tool Integration',
    markdown: 'Markdown 2',
    description: 'Description 2',
    source_url: 'https://example.com/2',
  },
  {
    title: 'Local Source Code Start | English | Dify',
    markdown: 'Markdown 3',
    description: 'Description 3',
    source_url: 'https://example.com/3',
  },
]

export default result