add tidb on qdrant type (#9831)

Co-authored-by: Zhaofeng Miao <522856232@qq.com>
This commit is contained in:
Jyong 2024-10-25 13:57:03 +08:00 committed by GitHub
parent fc2297a2ca
commit 18106a4fc6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 1648 additions and 1 deletions

View File

@ -571,6 +571,11 @@ class DataSetConfig(BaseSettings):
default=False, default=False,
) )
TIDB_SERVERLESS_NUMBER: PositiveInt = Field(
description="number of tidb serverless cluster",
default=500,
)
class WorkspaceConfig(BaseSettings): class WorkspaceConfig(BaseSettings):
""" """

View File

@ -27,6 +27,7 @@ from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig
from configs.middleware.vdb.qdrant_config import QdrantConfig from configs.middleware.vdb.qdrant_config import QdrantConfig
from configs.middleware.vdb.relyt_config import RelytConfig from configs.middleware.vdb.relyt_config import RelytConfig
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
from configs.middleware.vdb.upstash_config import UpstashConfig from configs.middleware.vdb.upstash_config import UpstashConfig
from configs.middleware.vdb.vikingdb_config import VikingDBConfig from configs.middleware.vdb.vikingdb_config import VikingDBConfig
@ -54,6 +55,11 @@ class VectorStoreConfig(BaseSettings):
default=None, default=None,
) )
VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field(
description="Enable whitelist for vector store.",
default=False,
)
class KeywordStoreConfig(BaseSettings): class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field( KEYWORD_STORE: str = Field(
@ -248,5 +254,6 @@ class MiddlewareConfig(
InternalTestConfig, InternalTestConfig,
VikingDBConfig, VikingDBConfig,
UpstashConfig, UpstashConfig,
TidbOnQdrantConfig,
): ):
pass pass

View File

@ -0,0 +1,65 @@
from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class TidbOnQdrantConfig(BaseSettings):
"""
Tidb on Qdrant configs
"""
TIDB_ON_QDRANT_URL: Optional[str] = Field(
description="Tidb on Qdrant url",
default=None,
)
TIDB_ON_QDRANT_API_KEY: Optional[str] = Field(
description="Tidb on Qdrant api key",
default=None,
)
TIDB_ON_QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
description="Tidb on Qdrant client timeout in seconds",
default=20,
)
TIDB_ON_QDRANT_GRPC_ENABLED: bool = Field(
description="whether enable grpc support for Tidb on Qdrant connection",
default=False,
)
TIDB_ON_QDRANT_GRPC_PORT: PositiveInt = Field(
description="Tidb on Qdrant grpc port",
default=6334,
)
TIDB_PUBLIC_KEY: Optional[str] = Field(
description="Tidb account public key",
default=None,
)
TIDB_PRIVATE_KEY: Optional[str] = Field(
description="Tidb account private key",
default=None,
)
TIDB_API_URL: Optional[str] = Field(
description="Tidb API url",
default=None,
)
TIDB_IAM_API_URL: Optional[str] = Field(
description="Tidb IAM API url",
default=None,
)
TIDB_REGION: Optional[str] = Field(
description="Tidb serverless region",
default="regions/aws-us-east-1",
)
TIDB_PROJECT_ID: Optional[str] = Field(
description="Tidb project id",
default=None,
)

View File

@ -639,6 +639,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.ORACLE | VectorType.ORACLE
| VectorType.ELASTICSEARCH | VectorType.ELASTICSEARCH
| VectorType.PGVECTOR | VectorType.PGVECTOR
| VectorType.TIDB_ON_QDRANT
): ):
return { return {
"retrieval_method": [ "retrieval_method": [

View File

@ -0,0 +1,17 @@
from typing import Optional
from pydantic import BaseModel
class ClusterEntity(BaseModel):
"""
Model Config Entity.
"""
name: str
cluster_id: str
displayName: str
region: str
spendingLimit: Optional[int] = 1000
version: str
createdBy: str

View File

@ -0,0 +1,526 @@
import json
import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
import qdrant_client
import requests
from flask import current_app
from pydantic import BaseModel
from qdrant_client.http import models as rest
from qdrant_client.http.models import (
FilterSelector,
HnswConfigDiff,
PayloadSchemaType,
TextIndexParams,
TextIndexType,
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, TidbAuthBinding
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]
class TidbOnQdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str] = None
timeout: float = 20
root_path: Optional[str] = None
grpc_port: int = 6334
prefer_grpc: bool = False
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {"path": path}
else:
return {
"url": self.endpoint,
"api_key": self.api_key,
"timeout": self.timeout,
"verify": False,
"grpc_port": self.grpc_port,
"prefer_grpc": self.prefer_grpc,
}
class TidbConfig(BaseModel):
api_url: str
public_key: str
private_key: str
class TidbOnQdrantVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: TidbOnQdrantConfig, distance_func: str = "Cosine"):
super().__init__(collection_name)
self._client_config = config
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
self._distance_func = distance_func.upper()
self._group_id = group_id
def get_type(self) -> str:
return VectorType.TIDB_ON_QDRANT
def to_index_struct(self) -> dict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
# get embedding vector size
vector_size = len(embeddings[0])
# get collection name
collection_name = self._collection_name
# create collection
self.create_collection(collection_name, vector_size)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
lock_name = "vector_indexing_lock_{}".format(collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
collection_name = collection_name or uuid.uuid4().hex
all_collection_name = []
collections_response = self._client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
from qdrant_client.http import models as rest
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[self._distance_func],
)
hnsw_config = HnswConfigDiff(
m=0,
payload_m=16,
ef_construct=100,
full_scan_threshold=10000,
max_indexing_threads=0,
on_disk=False,
)
self._client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
hnsw_config=hnsw_config,
timeout=int(self._client_config.timeout),
)
# create group_id payload index
self._client.create_payload_index(
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True,
)
self._client.create_payload_index(
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
added_ids = []
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)
return added_ids
def _generate_rest_batches(
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
texts_iterator = iter(texts)
embeddings_iterator = iter(embeddings)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = list(islice(ids_iterator, batch_size))
# Generate the embeddings for all the texts in a batch
batch_embeddings = list(islice(embeddings_iterator, batch_size))
points = [
rest.PointStruct(
id=point_id,
vector=vector,
payload=payload,
)
for point_id, vector, payload in zip(
batch_ids,
batch_embeddings,
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
group_id,
Field.GROUP_KEY.value,
),
)
]
yield batch_ids, points
@classmethod
def _build_payloads(
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str,
) -> list[dict]:
payloads = []
for i, text in enumerate(texts):
if text is None:
raise ValueError(
"At least one of the texts is None. Please remove it before "
"calling .from_texts or .add_texts on Qdrant instance."
)
metadata = metadatas[i] if metadatas is not None else None
payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})
return payloads
def delete_by_metadata_field(self, key: str, value: str):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
try:
filter = models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
)
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete(self):
from qdrant_client.http.exceptions import UnexpectedResponse
try:
self._client.delete_collection(collection_name=self._collection_name)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete_by_ids(self, ids: list[str]) -> None:
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
for node_id in ids:
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def text_exists(self, id: str) -> bool:
all_collection_name = []
collections_response = self._client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if self._collection_name not in all_collection_name:
return False
response = self._client.retrieve(collection_name=self._collection_name, ids=[id])
return len(response) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
],
)
results = self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
query_filter=filter,
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=kwargs.get("score_threshold", 0.0),
)
docs = []
for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs most similar by bm25.
Returns:
List of documents most similar to the query text and distance for each.
"""
from qdrant_client.http import models
scroll_filter = models.Filter(
must=[
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
]
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=kwargs.get("top_k", 2),
with_payload=True,
with_vectors=True,
)
results = response[0]
documents = []
for result in results:
if result:
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
document.metadata["vector"] = result.vector
documents.append(document)
return documents
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata=scored_point.payload.get(metadata_payload_key) or {},
)
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
tidb_auth_binding = (
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
if not tidb_auth_binding:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else:
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID,
dify_config.TIDB_API_URL,
dify_config.TIDB_IAM_API_URL,
dify_config.TIDB_PUBLIC_KEY,
dify_config.TIDB_PRIVATE_KEY,
dify_config.TIDB_REGION,
)
new_tidb_auth_binding = TidbAuthBinding(
cluster_id=new_cluster["cluster_id"],
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
tenant_id=dataset.tenant_id,
active=True,
status="ACTIVE",
)
db.session.add(new_tidb_auth_binding)
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
else:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_ON_QDRANT, collection_name))
config = current_app.config
return TidbOnQdrantVector(
collection_name=collection_name,
group_id=dataset.id,
config=TidbOnQdrantConfig(
endpoint=dify_config.TIDB_ON_QDRANT_URL,
api_key=TIDB_ON_QDRANT_API_KEY,
root_path=config.root_path,
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,
),
)
def create_tidb_serverless_cluster(self, tidb_config: TidbConfig, display_name: str, region: str):
"""
Creates a new TiDB Serverless cluster.
:param tidb_config: The configuration for the TiDB Cloud API.
:param display_name: The user-friendly display name of the cluster (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
region_object = {
"name": region,
}
labels = {
"tidb.cloud/project": "1372813089454548012",
}
cluster_data = {"displayName": display_name, "region": region_object, "labels": labels}
response = requests.post(
f"{tidb_config.api_url}/clusters",
json=cluster_data,
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
def change_tidb_serverless_root_password(self, tidb_config: TidbConfig, cluster_id: str, new_password: str):
"""
Changes the root password of a specific TiDB Serverless cluster.
:param tidb_config: The configuration for the TiDB Cloud API.
:param cluster_id: The ID of the cluster for which the password is to be changed (required).
:param new_password: The new password for the root user (required).
:return: The response from the API.
"""
body = {"password": new_password}
response = requests.put(
f"{tidb_config.api_url}/clusters/{cluster_id}/password",
json=body,
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()

View File

@ -0,0 +1,250 @@
import time
import uuid
import requests
from requests.auth import HTTPDigestAuth
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding
class TidbService:
@staticmethod
def create_tidb_serverless_cluster(
project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
):
"""
Creates a new TiDB Serverless cluster.
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param display_name: The user-friendly display name of the cluster (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
region_object = {
"name": region,
}
labels = {
"tidb.cloud/project": project_id,
}
spending_limit = {
"monthly": 100,
}
password = str(uuid.uuid4()).replace("-", "")[:16]
display_name = str(uuid.uuid4()).replace("-", "")[:16]
cluster_data = {
"displayName": display_name,
"region": region_object,
"labels": labels,
"spendingLimit": spending_limit,
"rootPassword": password,
}
response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key))
if response.status_code == 200:
response_data = response.json()
cluster_id = response_data["clusterId"]
retry_count = 0
max_retries = 30
while retry_count < max_retries:
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
if cluster_response["state"] == "ACTIVE":
user_prefix = cluster_response["userPrefix"]
return {
"cluster_id": cluster_id,
"cluster_name": display_name,
"account": f"{user_prefix}.root",
"password": password,
}
time.sleep(30) # wait 30 seconds before retrying
retry_count += 1
else:
response.raise_for_status()
@staticmethod
def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
"""
Deletes a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param cluster_id: The ID of the cluster to be deleted (required).
:return: The response from the API.
"""
response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
"""
Deletes a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param cluster_id: The ID of the cluster to be deleted (required).
:return: The response from the API.
"""
response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def change_tidb_serverless_root_password(
api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str
):
"""
Changes the root password of a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param cluster_id: The ID of the cluster for which the password is to be changed (required).+
:param account: The account for which the password is to be changed (required).
:param new_password: The new password for the root user (required).
:return: The response from the API.
"""
body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}
response = requests.patch(
f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
json=body,
auth=HTTPDigestAuth(public_key, private_key),
)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def batch_update_tidb_serverless_cluster_status(
tidb_serverless_list: list[TidbAuthBinding],
project_id: str,
api_url: str,
iam_url: str,
public_key: str,
private_key: str,
) -> list[dict]:
"""
Update the status of a new TiDB Serverless cluster.
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param display_name: The user-friendly display name of the cluster (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
clusters = []
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "FULL"}
response = requests.get(
f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
)
if response.status_code == 200:
response_data = response.json()
cluster_infos = []
for item in response_data["clusters"]:
state = item["state"]
userPrefix = item["userPrefix"]
if state == "ACTIVE" and len(userPrefix) > 0:
cluster_info = tidb_serverless_list_map[item["clusterId"]]
cluster_info.status = "ACTIVE"
cluster_info.account = f"{userPrefix}.root"
db.session.add(cluster_info)
db.session.commit()
else:
response.raise_for_status()
@staticmethod
def batch_create_tidb_serverless_cluster(
batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
) -> list[dict]:
"""
Creates a new TiDB Serverless cluster.
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param display_name: The user-friendly display name of the cluster (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
clusters = []
for _ in range(batch_size):
region_object = {
"name": region,
}
labels = {
"tidb.cloud/project": project_id,
}
spending_limit = {
"monthly": 10,
}
password = str(uuid.uuid4()).replace("-", "")[:16]
display_name = str(uuid.uuid4()).replace("-", "")
cluster_data = {
"cluster": {
"displayName": display_name,
"region": region_object,
"labels": labels,
"spendingLimit": spending_limit,
"rootPassword": password,
}
}
cache_key = f"tidb_serverless_cluster_password:{display_name}"
redis_client.setex(cache_key, 3600, password)
clusters.append(cluster_data)
request_body = {"requests": clusters}
response = requests.post(
f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key)
)
if response.status_code == 200:
response_data = response.json()
cluster_infos = []
for item in response_data["clusters"]:
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
password = redis_client.get(cache_key)
if not password:
continue
cluster_info = {
"cluster_id": item["clusterId"],
"cluster_name": item["displayName"],
"account": "root",
"password": password.decode("utf-8"),
}
cluster_infos.append(cluster_info)
return cluster_infos
else:
response.raise_for_status()

View File

@ -9,8 +9,9 @@ from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset, Whitelist
class AbstractVectorFactory(ABC): class AbstractVectorFactory(ABC):
@ -35,8 +36,18 @@ class Vector:
def _init_vector(self) -> BaseVector: def _init_vector(self) -> BaseVector:
vector_type = dify_config.VECTOR_STORE vector_type = dify_config.VECTOR_STORE
if self._dataset.index_struct_dict: if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict["type"] vector_type = self._dataset.index_struct_dict["type"]
else:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = (
db.session.query(Whitelist)
.filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none()
)
if whitelist:
vector_type = VectorType.TIDB_ON_QDRANT
if not vector_type: if not vector_type:
raise ValueError("Vector store must be specified.") raise ValueError("Vector store must be specified.")
@ -115,6 +126,10 @@ class Vector:
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory
return UpstashVectorFactory return UpstashVectorFactory
case VectorType.TIDB_ON_QDRANT:
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
return TidbOnQdrantVectorFactory
case _: case _:
raise ValueError(f"Vector store {vector_type} is not supported.") raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@ -19,3 +19,4 @@ class VectorType(str, Enum):
BAIDU = "baidu" BAIDU = "baidu"
VIKINGDB = "vikingdb" VIKINGDB = "vikingdb"
UPSTASH = "upstash" UPSTASH = "upstash"
TIDB_ON_QDRANT = "tidb_on_qdrant"

View File

@ -1,6 +1,7 @@
from datetime import timedelta from datetime import timedelta
from celery import Celery, Task from celery import Celery, Task
from celery.schedules import crontab
from flask import Flask from flask import Flask
from configs import dify_config from configs import dify_config
@ -55,6 +56,8 @@ def init_app(app: Flask) -> Celery:
imports = [ imports = [
"schedule.clean_embedding_cache_task", "schedule.clean_embedding_cache_task",
"schedule.clean_unused_datasets_task", "schedule.clean_unused_datasets_task",
"schedule.create_tidb_serverless_task",
"schedule.update_tidb_serverless_status_task",
] ]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME day = dify_config.CELERY_BEAT_SCHEDULER_TIME
beat_schedule = { beat_schedule = {
@ -66,6 +69,14 @@ def init_app(app: Flask) -> Celery:
"task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task", "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
"schedule": timedelta(days=day), "schedule": timedelta(days=day),
}, },
"create_tidb_serverless_task": {
"task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task",
"schedule": crontab(minute="0", hour="*"),
},
"update_tidb_serverless_status_task": {
"task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task",
"schedule": crontab(minute="30", hour="*"),
},
} }
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)

View File

@ -0,0 +1,51 @@
"""add-tidb-auth-binding
Revision ID: 0251a1c768cc
Revises: 63a83fcf12ba
Create Date: 2024-08-15 09:56:59.012490
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = '0251a1c768cc'
down_revision = 'bbadea11becb'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tidb_auth_bindings',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
sa.Column('cluster_id', sa.String(length=255), nullable=False),
sa.Column('cluster_name', sa.String(length=255), nullable=False),
sa.Column('active', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'CREATING'::character varying"), nullable=False),
sa.Column('account', sa.String(length=255), nullable=False),
sa.Column('password', sa.String(length=255), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tidb_auth_bindings_pkey')
)
with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op:
batch_op.create_index('tidb_auth_bindings_active_idx', ['active'], unique=False)
batch_op.create_index('tidb_auth_bindings_status_idx', ['status'], unique=False)
batch_op.create_index('tidb_auth_bindings_created_at_idx', ['created_at'], unique=False)
batch_op.create_index('tidb_auth_bindings_tenant_idx', ['tenant_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op:
batch_op.drop_index('tidb_auth_bindings_tenant_idx')
batch_op.drop_index('tidb_auth_bindings_created_at_idx')
batch_op.drop_index('tidb_auth_bindings_active_idx')
batch_op.drop_index('tidb_auth_bindings_status_idx')
op.drop_table('tidb_auth_bindings')
# ### end Alembic commands ###

View File

@ -0,0 +1,42 @@
"""add_white_list
Revision ID: 43fa78bc3b7d
Revises: 0251a1c768cc
Create Date: 2024-10-22 09:59:23.713716
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '43fa78bc3b7d'
down_revision = '0251a1c768cc'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('whitelists',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
sa.Column('category', sa.String(length=255), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='whitelists_pkey')
)
with op.batch_alter_table('whitelists', schema=None) as batch_op:
batch_op.create_index('whitelists_tenant_idx', ['tenant_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('whitelists', schema=None) as batch_op:
batch_op.drop_index('whitelists_tenant_idx')
op.drop_table('whitelists')
# ### end Alembic commands ###

View File

@ -704,6 +704,38 @@ class DatasetCollectionBinding(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class TidbAuthBinding(db.Model):
__tablename__ = "tidb_auth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
db.Index("tidb_auth_bindings_active_idx", "active"),
db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
db.Index("tidb_auth_bindings_status_idx", "status"),
)
id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=True)
cluster_id = db.Column(db.String(255), nullable=False)
cluster_name = db.Column(db.String(255), nullable=False)
active = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING"))
account = db.Column(db.String(255), nullable=False)
password = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class Whitelist(db.Model):
__tablename__ = "whitelists"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
db.Index("whitelists_tenant_idx", "tenant_id"),
)
id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=True)
category = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class DatasetPermission(db.Model): class DatasetPermission(db.Model):
__tablename__ = "dataset_permissions" __tablename__ = "dataset_permissions"
__table_args__ = ( __table_args__ = (

View File

@ -0,0 +1,56 @@
import time
import click
import app
from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding
@app.celery.task(queue="dataset")
def create_tidb_serverless_task():
click.echo(click.style("Start create tidb serverless task.", fg="green"))
tidb_serverless_number = dify_config.TIDB_SERVERLESS_NUMBER
start_at = time.perf_counter()
while True:
try:
# check the number of idle tidb serverless
idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count()
if idle_tidb_serverless_number >= tidb_serverless_number:
break
# create tidb serverless
iterations_per_thread = 20
create_clusters(iterations_per_thread)
except Exception as e:
click.echo(click.style(f"Error: {e}", fg="red"))
break
end_at = time.perf_counter()
click.echo(click.style("Create tidb serverless task success latency: {}".format(end_at - start_at), fg="green"))
def create_clusters(batch_size):
try:
new_clusters = TidbService.batch_create_tidb_serverless_cluster(
batch_size,
dify_config.TIDB_PROJECT_ID,
dify_config.TIDB_API_URL,
dify_config.TIDB_IAM_API_URL,
dify_config.TIDB_PUBLIC_KEY,
dify_config.TIDB_PRIVATE_KEY,
dify_config.TIDB_REGION,
)
for new_cluster in new_clusters:
tidb_auth_binding = TidbAuthBinding(
cluster_id=new_cluster["cluster_id"],
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
)
db.session.add(tidb_auth_binding)
db.session.commit()
except Exception as e:
click.echo(click.style(f"Error: {e}", fg="red"))

View File

@ -0,0 +1,51 @@
import time
import click
import app
from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from models.dataset import TidbAuthBinding
@app.celery.task(queue="dataset")
def update_tidb_serverless_status_task():
click.echo(click.style("Update tidb serverless status task.", fg="green"))
start_at = time.perf_counter()
while True:
try:
# check the number of idle tidb serverless
tidb_serverless_list = TidbAuthBinding.query.filter(
TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING"
).all()
if len(tidb_serverless_list) == 0:
break
# update tidb serverless status
iterations_per_thread = 20
update_clusters(tidb_serverless_list)
except Exception as e:
click.echo(click.style(f"Error: {e}", fg="red"))
break
end_at = time.perf_counter()
click.echo(
click.style("Update tidb serverless status task success latency: {}".format(end_at - start_at), fg="green")
)
def update_clusters(tidb_serverless_list: list[TidbAuthBinding]):
try:
# batch 20
for i in range(0, len(tidb_serverless_list), 20):
items = tidb_serverless_list[i : i + 20]
TidbService.batch_update_tidb_serverless_cluster_status(
items,
dify_config.TIDB_PROJECT_ID,
dify_config.TIDB_API_URL,
dify_config.TIDB_IAM_API_URL,
dify_config.TIDB_PUBLIC_KEY,
dify_config.TIDB_PRIVATE_KEY,
)
except Exception as e:
click.echo(click.style(f"Error: {e}", fg="red"))

44
api/services/auth/jina.py Normal file
View File

@ -0,0 +1,44 @@
import json
import requests
from services.auth.api_key_auth_base import ApiKeyAuthBase
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
self.api_key = credentials.get("config").get("api_key", None)
if not self.api_key:
raise ValueError("No API key provided")
def validate_credentials(self):
headers = self._prepare_headers()
options = {
"url": "https://example.com",
}
response = self._post_request("https://r.jina.ai", options, headers)
if response.status_code == 200:
return True
else:
self._handle_error(response)
def _prepare_headers(self):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
else:
if response.text:
error_message = json.loads(response.text).get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")

View File

@ -0,0 +1,40 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import cn from '@/utils/classnames'
import Checkbox from '@/app/components/base/checkbox'
import Tooltip from '@/app/components/base/tooltip'
type Props = {
className?: string
isChecked: boolean
onChange: (isChecked: boolean) => void
label: string
labelClassName?: string
tooltip?: string
}
const CheckboxWithLabel: FC<Props> = ({
className = '',
isChecked,
onChange,
label,
labelClassName,
tooltip,
}) => {
return (
<label className={cn(className, 'flex items-center h-7 space-x-2')}>
<Checkbox checked={isChecked} onCheck={() => onChange(!isChecked)} />
<div className={cn(labelClassName, 'text-sm font-normal text-gray-800')}>{label}</div>
{tooltip && (
<Tooltip
popupContent={
<div className='w-[200px]'>{tooltip}</div>
}
triggerClassName='ml-0.5 w-4 h-4'
/>
)}
</label>
)
}
export default React.memo(CheckboxWithLabel)

View File

@ -0,0 +1,30 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import cn from '@/utils/classnames'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
type Props = {
className?: string
title: string
errorMsg?: string
}
const ErrorMessage: FC<Props> = ({
className,
title,
errorMsg,
}) => {
return (
<div className={cn(className, 'py-2 px-4 border-t border-gray-200 bg-[#FFFAEB]')}>
<div className='flex items-center h-5'>
<AlertTriangle className='mr-2 w-4 h-4 text-[#F79009]' />
<div className='text-sm font-medium text-[#DC6803]'>{title}</div>
</div>
{errorMsg && (
<div className='mt-1 pl-6 leading-[18px] text-xs font-normal text-gray-700'>{errorMsg}</div>
)}
</div>
)
}
export default React.memo(ErrorMessage)

View File

@ -0,0 +1,54 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import Input from './input'
import cn from '@/utils/classnames'
import Tooltip from '@/app/components/base/tooltip'
type Props = {
className?: string
label: string
labelClassName?: string
value: string | number
onChange: (value: string | number) => void
isRequired?: boolean
placeholder?: string
isNumber?: boolean
tooltip?: string
}
const Field: FC<Props> = ({
className,
label,
labelClassName,
value,
onChange,
isRequired = false,
placeholder = '',
isNumber = false,
tooltip,
}) => {
return (
<div className={cn(className)}>
<div className='flex py-[7px]'>
<div className={cn(labelClassName, 'flex items-center h-[18px] text-[13px] font-medium text-gray-900')}>{label} </div>
{isRequired && <span className='ml-0.5 text-xs font-semibold text-[#D92D20]'>*</span>}
{tooltip && (
<Tooltip
popupContent={
<div className='w-[200px]'>{tooltip}</div>
}
triggerClassName='ml-0.5 w-4 h-4'
/>
)}
</div>
<Input
value={value}
onChange={onChange}
placeholder={placeholder}
isNumber={isNumber}
/>
</div>
)
}
export default React.memo(Field)

View File

@ -0,0 +1,58 @@
'use client'
import type { FC } from 'react'
import React, { useCallback } from 'react'
type Props = {
value: string | number
onChange: (value: string | number) => void
placeholder?: string
isNumber?: boolean
}
const MIN_VALUE = 0
const Input: FC<Props> = ({
value,
onChange,
placeholder = '',
isNumber = false,
}) => {
const handleChange = useCallback((e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value
if (isNumber) {
let numberValue = parseInt(value, 10) // integer only
if (isNaN(numberValue)) {
onChange('')
return
}
if (numberValue < MIN_VALUE)
numberValue = MIN_VALUE
onChange(numberValue)
return
}
onChange(value)
}, [isNumber, onChange])
const otherOption = (() => {
if (isNumber) {
return {
min: MIN_VALUE,
}
}
return {
}
})()
return (
<input
type={isNumber ? 'number' : 'text'}
{...otherOption}
value={value}
onChange={handleChange}
className='flex h-9 w-full py-1 px-2 rounded-lg text-xs leading-normal bg-gray-100 caret-primary-600 hover:bg-gray-100 focus:ring-1 focus:ring-inset focus:ring-gray-200 focus-visible:outline-none focus:bg-gray-50 placeholder:text-gray-400'
placeholder={placeholder}
/>
)
}
export default React.memo(Input)

View File

@ -0,0 +1,55 @@
'use client'
import { useBoolean } from 'ahooks'
import type { FC } from 'react'
import React, { useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import cn from '@/utils/classnames'
import { Settings04 } from '@/app/components/base/icons/src/vender/line/general'
import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows'
const I18N_PREFIX = 'datasetCreation.stepOne.website'
type Props = {
className?: string
children: React.ReactNode
controlFoldOptions?: number
}
const OptionsWrap: FC<Props> = ({
className = '',
children,
controlFoldOptions,
}) => {
const { t } = useTranslation()
const [fold, {
toggle: foldToggle,
setTrue: foldHide,
}] = useBoolean(false)
useEffect(() => {
if (controlFoldOptions)
foldHide()
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [controlFoldOptions])
return (
<div className={cn(className, !fold ? 'mb-0' : 'mb-3')}>
<div
className='flex justify-between items-center h-[26px] py-1 cursor-pointer select-none'
onClick={foldToggle}
>
<div className='flex items-center text-gray-700'>
<Settings04 className='mr-1 w-4 h-4' />
<div className='text-[13px] font-semibold text-gray-800 uppercase'>{t(`${I18N_PREFIX}.options`)}</div>
</div>
<ChevronRight className={cn(!fold && 'rotate-90', 'w-4 h-4 text-gray-500')} />
</div>
{!fold && (
<div className='mb-4'>
{children}
</div>
)}
</div>
)
}
export default React.memo(OptionsWrap)

View File

@ -0,0 +1,48 @@
'use client'
import type { FC } from 'react'
import React, { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Input from './input'
import Button from '@/app/components/base/button'
const I18N_PREFIX = 'datasetCreation.stepOne.website'
type Props = {
isRunning: boolean
onRun: (url: string) => void
}
const UrlInput: FC<Props> = ({
isRunning,
onRun,
}) => {
const { t } = useTranslation()
const [url, setUrl] = useState('')
const handleUrlChange = useCallback((url: string | number) => {
setUrl(url as string)
}, [])
const handleOnRun = useCallback(() => {
if (isRunning)
return
onRun(url)
}, [isRunning, onRun, url])
return (
<div className='flex items-center justify-between'>
<Input
value={url}
onChange={handleUrlChange}
placeholder='https://docs.dify.ai'
/>
<Button
variant='primary'
onClick={handleOnRun}
className='ml-2'
loading={isRunning}
>
{!isRunning ? t(`${I18N_PREFIX}.run`) : ''}
</Button>
</div>
)
}
export default React.memo(UrlInput)

View File

@ -0,0 +1,40 @@
'use client'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import cn from '@/utils/classnames'
import type { CrawlResultItem as CrawlResultItemType } from '@/models/datasets'
import Checkbox from '@/app/components/base/checkbox'
type Props = {
payload: CrawlResultItemType
isChecked: boolean
isPreview: boolean
onCheckChange: (checked: boolean) => void
onPreview: () => void
}
const CrawledResultItem: FC<Props> = ({
isPreview,
payload,
isChecked,
onCheckChange,
onPreview,
}) => {
const { t } = useTranslation()
const handleCheckChange = useCallback(() => {
onCheckChange(!isChecked)
}, [isChecked, onCheckChange])
return (
<div className={cn(isPreview ? 'border-[#D1E0FF] bg-primary-50 shadow-xs' : 'group hover:bg-gray-100', 'rounded-md px-2 py-[5px] cursor-pointer border border-transparent')}>
<div className='flex items-center h-5'>
<Checkbox className='group-hover:border-2 group-hover:border-primary-600 mr-2 shrink-0' checked={isChecked} onCheck={handleCheckChange} />
<div className='grow w-0 truncate text-sm font-medium text-gray-700' title={payload.title}>{payload.title}</div>
<div onClick={onPreview} className='hidden group-hover:flex items-center h-6 px-2 text-xs rounded-md font-medium text-gray-500 uppercase hover:bg-gray-50'>{t('datasetCreation.stepOne.website.preview')}</div>
</div>
<div className='mt-0.5 truncate pl-6 leading-[18px] text-xs font-normal text-gray-500' title={payload.source_url}>{payload.source_url}</div>
</div>
)
}
export default React.memo(CrawledResultItem)

View File

@ -0,0 +1,87 @@
'use client'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import CheckboxWithLabel from './base/checkbox-with-label'
import CrawledResultItem from './crawled-result-item'
import cn from '@/utils/classnames'
import type { CrawlResultItem } from '@/models/datasets'
const I18N_PREFIX = 'datasetCreation.stepOne.website'
type Props = {
className?: string
list: CrawlResultItem[]
checkedList: CrawlResultItem[]
onSelectedChange: (selected: CrawlResultItem[]) => void
onPreview: (payload: CrawlResultItem) => void
usedTime: number
}
const CrawledResult: FC<Props> = ({
className = '',
list,
checkedList,
onSelectedChange,
onPreview,
usedTime,
}) => {
const { t } = useTranslation()
const isCheckAll = checkedList.length === list.length
const handleCheckedAll = useCallback(() => {
if (!isCheckAll)
onSelectedChange(list)
else
onSelectedChange([])
}, [isCheckAll, list, onSelectedChange])
const handleItemCheckChange = useCallback((item: CrawlResultItem) => {
return (checked: boolean) => {
if (checked)
onSelectedChange([...checkedList, item])
else
onSelectedChange(checkedList.filter(checkedItem => checkedItem.source_url !== item.source_url))
}
}, [checkedList, onSelectedChange])
const [previewIndex, setPreviewIndex] = React.useState<number>(-1)
const handlePreview = useCallback((index: number) => {
return () => {
setPreviewIndex(index)
onPreview(list[index])
}
}, [list, onPreview])
return (
<div className={cn(className, 'border-t border-gray-200')}>
<div className='flex items-center justify-between h-[34px] px-4 bg-gray-50 shadow-xs border-b-[0.5px] border-black/8 text-xs font-normal text-gray-700'>
<CheckboxWithLabel
isChecked={isCheckAll}
onChange={handleCheckedAll} label={isCheckAll ? t(`${I18N_PREFIX}.resetAll`) : t(`${I18N_PREFIX}.selectAll`)}
labelClassName='!font-medium'
/>
<div>{t(`${I18N_PREFIX}.scrapTimeInfo`, {
total: list.length,
time: usedTime.toFixed(1),
})}</div>
</div>
<div className='p-2'>
{list.map((item, index) => (
<CrawledResultItem
key={item.source_url}
isPreview={index === previewIndex}
onPreview={handlePreview(index)}
payload={item}
isChecked={checkedList.some(checkedItem => checkedItem.source_url === item.source_url)}
onCheckChange={handleItemCheckChange(item)}
/>
))}
</div>
</div>
)
}
export default React.memo(CrawledResult)

View File

@ -0,0 +1,37 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import cn from '@/utils/classnames'
import { RowStruct } from '@/app/components/base/icons/src/public/other'
type Props = {
className?: string
crawledNum: number
totalNum: number
}
const Crawling: FC<Props> = ({
className = '',
crawledNum,
totalNum,
}) => {
const { t } = useTranslation()
return (
<div className={cn(className, 'border-t border-gray-200')}>
<div className='flex items-center h-[34px] px-4 bg-gray-50 shadow-xs border-b-[0.5px] border-black/8 text-xs font-normal text-gray-700'>
{t('datasetCreation.stepOne.website.totalPageScraped')} {crawledNum}/{totalNum}
</div>
<div className='p-2'>
{['', '', '', ''].map((item, index) => (
<div className='py-[5px]' key={index}>
<RowStruct />
</div>
))}
</div>
</div>
)
}
export default React.memo(Crawling)

View File

@ -0,0 +1,24 @@
import type { CrawlResultItem } from '@/models/datasets'
const result: CrawlResultItem[] = [
{
title: 'Start the frontend Docker container separately',
markdown: 'Markdown 1',
description: 'Description 1',
source_url: 'https://example.com/1',
},
{
title: 'Advanced Tool Integration',
markdown: 'Markdown 2',
description: 'Description 2',
source_url: 'https://example.com/2',
},
{
title: 'Local Source Code Start | English | Dify',
markdown: 'Markdown 3',
description: 'Description 3',
source_url: 'https://example.com/3',
},
]
export default result