add pgvecto_rs support and upgrade SQLAlchemy (#3833)

This commit is contained in:
Jyong 2024-04-29 11:58:17 +08:00 committed by GitHub
parent 975b2fb79e
commit 3e9dbe3e0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 584 additions and 220 deletions

View File

@ -61,19 +61,21 @@ jobs:
- name: Run Workflow - name: Run Workflow
run: dev/pytest/pytest_workflow.sh run: dev/pytest/pytest_workflow.sh
- name: Set up Vector Stores (Weaviate, Qdrant and Milvus) - name: Set up Vector Stores (Weaviate, Qdrant, Milvus, PgVecto-RS)
uses: hoverkraft-tech/compose-action@v2.0.0 uses: hoverkraft-tech/compose-action@v2.0.0
with: with:
compose-file: | compose-file: |
docker/docker-compose.middleware.yaml docker/docker-compose.middleware.yaml
docker/docker-compose.qdrant.yaml docker/docker-compose.qdrant.yaml
docker/docker-compose.milvus.yaml docker/docker-compose.milvus.yaml
docker/docker-compose.pgvecto-rs.yaml
services: | services: |
weaviate weaviate
qdrant qdrant
etcd etcd
minio minio
milvus-standalone milvus-standalone
pgvecto-rs
- name: Test Vector Stores - name: Test Vector Stores
run: dev/pytest/pytest_vdb.sh run: dev/pytest/pytest_vdb.sh

View File

@ -62,7 +62,7 @@ GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON=your-google-service-account-json-base64-stri
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration, support: weaviate, qdrant, milvus, relyt # Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs
VECTOR_STORE=weaviate VECTOR_STORE=weaviate
# Weaviate configuration # Weaviate configuration
@ -92,6 +92,13 @@ RELYT_USER=postgres
RELYT_PASSWORD=postgres RELYT_PASSWORD=postgres
RELYT_DATABASE=postgres RELYT_DATABASE=postgres
# PGVECTO_RS configuration
PGVECTO_RS_HOST=localhost
PGVECTO_RS_PORT=5431
PGVECTO_RS_USER=postgres
PGVECTO_RS_PASSWORD=difyai123456
PGVECTO_RS_DATABASE=postgres
# Upload configuration # Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15 UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5 UPLOAD_FILE_BATCH_LIMIT=5

View File

@ -251,6 +251,13 @@ class Config:
self.RELYT_PASSWORD = get_env('RELYT_PASSWORD') self.RELYT_PASSWORD = get_env('RELYT_PASSWORD')
self.RELYT_DATABASE = get_env('RELYT_DATABASE') self.RELYT_DATABASE = get_env('RELYT_DATABASE')
# pgvecto rs settings
self.PGVECTO_RS_HOST = get_env('PGVECTO_RS_HOST')
self.PGVECTO_RS_PORT = get_env('PGVECTO_RS_PORT')
self.PGVECTO_RS_USER = get_env('PGVECTO_RS_USER')
self.PGVECTO_RS_PASSWORD = get_env('PGVECTO_RS_PASSWORD')
self.PGVECTO_RS_DATABASE = get_env('PGVECTO_RS_DATABASE')
# ------------------------ # ------------------------
# Mail Configurations. # Mail Configurations.
# ------------------------ # ------------------------

View File

@ -476,7 +476,7 @@ class DatasetRetrievalSettingApi(Resource):
@account_initialization_required @account_initialization_required
def get(self): def get(self):
vector_type = current_app.config['VECTOR_STORE'] vector_type = current_app.config['VECTOR_STORE']
if vector_type == 'milvus' or vector_type == 'relyt': if vector_type == 'milvus' or vector_type == 'pgvecto_rs' or vector_type == 'relyt':
return { return {
'retrieval_method': [ 'retrieval_method': [
'semantic_search' 'semantic_search'

View File

@ -0,0 +1,12 @@
from uuid import UUID
from numpy import ndarray
from sqlalchemy.orm import DeclarativeBase, Mapped
class CollectionORM(DeclarativeBase):
__tablename__: str
id: Mapped[UUID]
text: Mapped[str]
meta: Mapped[dict]
vector: Mapped[ndarray]

View File

@ -0,0 +1,224 @@
import logging
from typing import Any
from uuid import UUID, uuid4
from numpy import ndarray
from pgvecto_rs.sqlalchemy import Vector
from pydantic import BaseModel, root_validator
from sqlalchemy import Float, String, create_engine, insert, select, text
from sqlalchemy import text as sql_text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, Session, mapped_column
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class PgvectoRSConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config PGVECTO_RS_HOST is required")
if not values['port']:
raise ValueError("config PGVECTO_RS_PORT is required")
if not values['user']:
raise ValueError("config PGVECTO_RS_USER is required")
if not values['password']:
raise ValueError("config PGVECTO_RS_PASSWORD is required")
if not values['database']:
raise ValueError("config PGVECTO_RS_DATABASE is required")
return values
class PGVectoRS(BaseVector):
def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int):
super().__init__(collection_name)
self._client_config = config
self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
self._client = create_engine(self._url)
with Session(self._client) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
session.commit()
self._fields = []
class _Table(CollectionORM):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True} # noqa: RUF012
id: Mapped[UUID] = mapped_column(
postgresql.UUID(as_uuid=True),
primary_key=True,
)
text: Mapped[str] = mapped_column(String)
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
vector: Mapped[ndarray] = mapped_column(Vector(dim))
self._table = _Table
self._distance_op = "<=>"
def get_type(self) -> str:
return 'pgvecto-rs'
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.create_collection(len(embeddings[0]))
self.add_texts(texts, embeddings)
def create_collection(self, dimension: int):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
with Session(self._client) as session:
create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS {self._collection_name} (
id UUID PRIMARY KEY,
text TEXT NOT NULL,
meta JSONB NOT NULL,
vector vector({dimension}) NOT NULL
) using heap;
""")
session.execute(create_statement)
index_statement = sql_text(f"""
CREATE INDEX IF NOT EXISTS {index_name}
ON {self._collection_name} USING vectors(vector vector_l2_ops)
WITH (options = $$
optimizing.optimizing_threads = 30
segment.max_growing_segment_size = 2000
segment.max_sealed_segment_size = 30000000
[indexing.hnsw]
m=30
ef_construction=500
$$);
""")
session.execute(index_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
pks = []
with Session(self._client) as session:
for document, embedding in zip(documents, embeddings):
pk = uuid4()
session.execute(
insert(self._table).values(
id=pk,
text=document.page_content,
meta=document.metadata,
vector=embedding,
),
)
pks.append(pk)
session.commit()
return pks
def delete_by_document_id(self, document_id: str):
ids = self.get_ids_by_metadata_field('document_id', document_id)
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {'ids': ids})
session.commit()
def get_ids_by_metadata_field(self, key: str, value: str):
result = None
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; "
)
result = session.execute(select_statement).fetchall()
if result:
return [item[0] for item in result]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {'ids': ids})
session.commit()
def delete_by_ids(self, ids: list[str]) -> None:
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); "
)
result = session.execute(select_statement, {'doc_ids': ids}).fetchall()
if result:
ids = [item[0] for item in result]
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {'ids': ids})
session.commit()
def delete(self) -> None:
with Session(self._client) as session:
session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}"))
session.commit()
def text_exists(self, id: str) -> bool:
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
)
result = session.execute(select_statement).fetchall()
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
with Session(self._client) as session:
stmt = (
select(
self._table,
self._table.vector.op(self._distance_op, return_type=Float)(
query_vector,
).label("distance"),
)
.limit(kwargs.get('top_k', 2))
.order_by("distance")
)
res = session.execute(stmt)
results = [(row[0], row[1]) for row in res]
# Organize results.
docs = []
for record, dis in results:
metadata = record.meta
score = 1 - dis
metadata['score'] = score
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if score > score_threshold:
doc = Document(page_content=record.text,
metadata=metadata)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# with Session(self._client) as session:
# select_statement = sql_text(
# f"SELECT text, meta FROM {self._collection_name} WHERE to_tsvector(text) @@ '{query}'::tsquery"
# )
# results = session.execute(select_statement).fetchall()
# if results:
# docs = []
# for result in results:
# doc = Document(page_content=result[0],
# metadata=result[1])
# docs.append(doc)
# return docs
return []

View File

@ -235,7 +235,7 @@ class RelytVector(BaseVector):
docs = [] docs = []
for document, score in results: for document, score in results:
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if score > score_threshold: if 1 - score > score_threshold:
docs.append(document) docs.append(document)
return docs return docs

View File

@ -139,6 +139,31 @@ class Vector:
), ),
group_id=self._dataset.id group_id=self._dataset.id
) )
elif vector_type == "pgvecto_rs":
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix.lower()
else:
dataset_id = self._dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
index_struct_dict = {
"type": 'pgvecto_rs',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
dim = len(self._embeddings.embed_query("pgvecto_rs"))
return PGVectoRS(
collection_name=collection_name,
config=PgvectoRSConfig(
host=config.get('PGVECTO_RS_HOST'),
port=config.get('PGVECTO_RS_PORT'),
user=config.get('PGVECTO_RS_USER'),
password=config.get('PGVECTO_RS_PASSWORD'),
database=config.get('PGVECTO_RS_DATABASE'),
),
dim=dim
)
else: else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

View File

@ -6,6 +6,7 @@ Create Date: ${create_date}
""" """
from alembic import op from alembic import op
import models as models
import sqlalchemy as sa import sqlalchemy as sa
${imports if imports else ""} ${imports if imports else ""}

View File

@ -1,5 +1,8 @@
from enum import Enum from enum import Enum
from sqlalchemy import CHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
class CreatedByRole(Enum): class CreatedByRole(Enum):
""" """
@ -42,3 +45,27 @@ class CreatedFrom(Enum):
if role.value == value: if role.value == value:
return role return role
raise ValueError(f'invalid createdFrom value {value}') raise ValueError(f'invalid createdFrom value {value}')
class StringUUID(TypeDecorator):
impl = CHAR
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return value
elif dialect.name == 'postgresql':
return str(value)
else:
return value.hex
def load_dialect_impl(self, dialect):
if dialect.name == 'postgresql':
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(36))
def process_result_value(self, value, dialect):
if value is None:
return value
return str(value)

View File

@ -2,9 +2,9 @@ import enum
import json import json
from flask_login import UserMixin from flask_login import UserMixin
from sqlalchemy.dialects.postgresql import UUID
from extensions.ext_database import db from extensions.ext_database import db
from models import StringUUID
class AccountStatus(str, enum.Enum): class AccountStatus(str, enum.Enum):
@ -22,7 +22,7 @@ class Account(UserMixin, db.Model):
db.Index('account_email_idx', 'email') db.Index('account_email_idx', 'email')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
email = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False)
password = db.Column(db.String(255), nullable=True) password = db.Column(db.String(255), nullable=True)
@ -128,7 +128,7 @@ class Tenant(db.Model):
db.PrimaryKeyConstraint('id', name='tenant_pkey'), db.PrimaryKeyConstraint('id', name='tenant_pkey'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
encrypt_public_key = db.Column(db.Text) encrypt_public_key = db.Column(db.Text)
plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
@ -168,12 +168,12 @@ class TenantAccountJoin(db.Model):
db.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join') db.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
account_id = db.Column(UUID, nullable=False) account_id = db.Column(StringUUID, nullable=False)
current = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) current = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
role = db.Column(db.String(16), nullable=False, server_default='normal') role = db.Column(db.String(16), nullable=False, server_default='normal')
invited_by = db.Column(UUID, nullable=True) invited_by = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -186,8 +186,8 @@ class AccountIntegrate(db.Model):
db.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id') db.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
account_id = db.Column(UUID, nullable=False) account_id = db.Column(StringUUID, nullable=False)
provider = db.Column(db.String(16), nullable=False) provider = db.Column(db.String(16), nullable=False)
open_id = db.Column(db.String(255), nullable=False) open_id = db.Column(db.String(255), nullable=False)
encrypted_token = db.Column(db.String(255), nullable=False) encrypted_token = db.Column(db.String(255), nullable=False)
@ -208,7 +208,7 @@ class InvitationCode(db.Model):
code = db.Column(db.String(32), nullable=False) code = db.Column(db.String(32), nullable=False)
status = db.Column(db.String(16), nullable=False, server_default=db.text("'unused'::character varying")) status = db.Column(db.String(16), nullable=False, server_default=db.text("'unused'::character varying"))
used_at = db.Column(db.DateTime) used_at = db.Column(db.DateTime)
used_by_tenant_id = db.Column(UUID) used_by_tenant_id = db.Column(StringUUID)
used_by_account_id = db.Column(UUID) used_by_account_id = db.Column(StringUUID)
deprecated_at = db.Column(db.DateTime) deprecated_at = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

View File

@ -1,8 +1,7 @@
import enum import enum
from sqlalchemy.dialects.postgresql import UUID
from extensions.ext_database import db from extensions.ext_database import db
from models import StringUUID
class APIBasedExtensionPoint(enum.Enum): class APIBasedExtensionPoint(enum.Enum):
@ -19,8 +18,8 @@ class APIBasedExtension(db.Model):
db.Index('api_based_extension_tenant_idx', 'tenant_id'), db.Index('api_based_extension_tenant_idx', 'tenant_id'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
api_endpoint = db.Column(db.String(255), nullable=False) api_endpoint = db.Column(db.String(255), nullable=False)
api_key = db.Column(db.Text, nullable=False) api_key = db.Column(db.Text, nullable=False)

View File

@ -4,10 +4,11 @@ import pickle
from json import JSONDecodeError from json import JSONDecodeError
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.dialects.postgresql import JSONB
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
from models import StringUUID
from models.account import Account from models.account import Account
from models.model import App, Tag, TagBinding, UploadFile from models.model import App, Tag, TagBinding, UploadFile
@ -22,8 +23,8 @@ class Dataset(db.Model):
INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None] INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=True) description = db.Column(db.Text, nullable=True)
provider = db.Column(db.String(255), nullable=False, provider = db.Column(db.String(255), nullable=False,
@ -33,15 +34,15 @@ class Dataset(db.Model):
data_source_type = db.Column(db.String(255)) data_source_type = db.Column(db.String(255))
indexing_technique = db.Column(db.String(255), nullable=True) indexing_technique = db.Column(db.String(255), nullable=True)
index_struct = db.Column(db.Text, nullable=True) index_struct = db.Column(db.Text, nullable=True)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)')) server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by = db.Column(UUID, nullable=True) updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)')) server_default=db.text('CURRENT_TIMESTAMP(0)'))
embedding_model = db.Column(db.String(255), nullable=True) embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(UUID, nullable=True) collection_binding_id = db.Column(StringUUID, nullable=True)
retrieval_model = db.Column(JSONB, nullable=True) retrieval_model = db.Column(JSONB, nullable=True)
@property @property
@ -145,13 +146,13 @@ class DatasetProcessRule(db.Model):
db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'), db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'),
) )
id = db.Column(UUID, nullable=False, id = db.Column(StringUUID, nullable=False,
server_default=db.text('uuid_generate_v4()')) server_default=db.text('uuid_generate_v4()'))
dataset_id = db.Column(UUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False)
mode = db.Column(db.String(255), nullable=False, mode = db.Column(db.String(255), nullable=False,
server_default=db.text("'automatic'::character varying")) server_default=db.text("'automatic'::character varying"))
rules = db.Column(db.Text, nullable=True) rules = db.Column(db.Text, nullable=True)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)')) server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -197,19 +198,19 @@ class Document(db.Model):
) )
# initial fields # initial fields
id = db.Column(UUID, nullable=False, id = db.Column(StringUUID, nullable=False,
server_default=db.text('uuid_generate_v4()')) server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(UUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False) position = db.Column(db.Integer, nullable=False)
data_source_type = db.Column(db.String(255), nullable=False) data_source_type = db.Column(db.String(255), nullable=False)
data_source_info = db.Column(db.Text, nullable=True) data_source_info = db.Column(db.Text, nullable=True)
dataset_process_rule_id = db.Column(UUID, nullable=True) dataset_process_rule_id = db.Column(StringUUID, nullable=True)
batch = db.Column(db.String(255), nullable=False) batch = db.Column(db.String(255), nullable=False)
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
created_from = db.Column(db.String(255), nullable=False) created_from = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_api_request_id = db.Column(UUID, nullable=True) created_api_request_id = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)')) server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -234,7 +235,7 @@ class Document(db.Model):
# pause # pause
is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
paused_by = db.Column(UUID, nullable=True) paused_by = db.Column(StringUUID, nullable=True)
paused_at = db.Column(db.DateTime, nullable=True) paused_at = db.Column(db.DateTime, nullable=True)
# error # error
@ -247,11 +248,11 @@ class Document(db.Model):
enabled = db.Column(db.Boolean, nullable=False, enabled = db.Column(db.Boolean, nullable=False,
server_default=db.text('true')) server_default=db.text('true'))
disabled_at = db.Column(db.DateTime, nullable=True) disabled_at = db.Column(db.DateTime, nullable=True)
disabled_by = db.Column(UUID, nullable=True) disabled_by = db.Column(StringUUID, nullable=True)
archived = db.Column(db.Boolean, nullable=False, archived = db.Column(db.Boolean, nullable=False,
server_default=db.text('false')) server_default=db.text('false'))
archived_reason = db.Column(db.String(255), nullable=True) archived_reason = db.Column(db.String(255), nullable=True)
archived_by = db.Column(UUID, nullable=True) archived_by = db.Column(StringUUID, nullable=True)
archived_at = db.Column(db.DateTime, nullable=True) archived_at = db.Column(db.DateTime, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)')) server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -356,11 +357,11 @@ class DocumentSegment(db.Model):
) )
# initial fields # initial fields
id = db.Column(UUID, nullable=False, id = db.Column(StringUUID, nullable=False,
server_default=db.text('uuid_generate_v4()')) server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(UUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(UUID, nullable=False) document_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False) position = db.Column(db.Integer, nullable=False)
content = db.Column(db.Text, nullable=False) content = db.Column(db.Text, nullable=False)
answer = db.Column(db.Text, nullable=True) answer = db.Column(db.Text, nullable=True)
@ -377,13 +378,13 @@ class DocumentSegment(db.Model):
enabled = db.Column(db.Boolean, nullable=False, enabled = db.Column(db.Boolean, nullable=False,
server_default=db.text('true')) server_default=db.text('true'))
disabled_at = db.Column(db.DateTime, nullable=True) disabled_at = db.Column(db.DateTime, nullable=True)
disabled_by = db.Column(UUID, nullable=True) disabled_by = db.Column(StringUUID, nullable=True)
status = db.Column(db.String(255), nullable=False, status = db.Column(db.String(255), nullable=False,
server_default=db.text("'waiting'::character varying")) server_default=db.text("'waiting'::character varying"))
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)')) server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by = db.Column(UUID, nullable=True) updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)')) server_default=db.text('CURRENT_TIMESTAMP(0)'))
indexing_at = db.Column(db.DateTime, nullable=True) indexing_at = db.Column(db.DateTime, nullable=True)
@ -421,9 +422,9 @@ class AppDatasetJoin(db.Model):
db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'), db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'),
) )
id = db.Column(UUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(UUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
@property @property
@ -438,13 +439,13 @@ class DatasetQuery(db.Model):
db.Index('dataset_query_dataset_id_idx', 'dataset_id'), db.Index('dataset_query_dataset_id_idx', 'dataset_id'),
) )
id = db.Column(UUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
dataset_id = db.Column(UUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False)
content = db.Column(db.Text, nullable=False) content = db.Column(db.Text, nullable=False)
source = db.Column(db.String(255), nullable=False) source = db.Column(db.String(255), nullable=False)
source_app_id = db.Column(UUID, nullable=True) source_app_id = db.Column(StringUUID, nullable=True)
created_by_role = db.Column(db.String, nullable=False) created_by_role = db.Column(db.String, nullable=False)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
@ -455,8 +456,8 @@ class DatasetKeywordTable(db.Model):
db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'), db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'),
) )
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
dataset_id = db.Column(UUID, nullable=False, unique=True) dataset_id = db.Column(StringUUID, nullable=False, unique=True)
keyword_table = db.Column(db.Text, nullable=False) keyword_table = db.Column(db.Text, nullable=False)
data_source_type = db.Column(db.String(255), nullable=False, data_source_type = db.Column(db.String(255), nullable=False,
server_default=db.text("'database'::character varying")) server_default=db.text("'database'::character varying"))
@ -501,7 +502,7 @@ class Embedding(db.Model):
db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx') db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx')
) )
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
model_name = db.Column(db.String(40), nullable=False, model_name = db.Column(db.String(40), nullable=False,
server_default=db.text("'text-embedding-ada-002'::character varying")) server_default=db.text("'text-embedding-ada-002'::character varying"))
hash = db.Column(db.String(64), nullable=False) hash = db.Column(db.String(64), nullable=False)
@ -525,7 +526,7 @@ class DatasetCollectionBinding(db.Model):
) )
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
provider_name = db.Column(db.String(40), nullable=False) provider_name = db.Column(db.String(40), nullable=False)
model_name = db.Column(db.String(40), nullable=False) model_name = db.Column(db.String(40), nullable=False)
type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)

View File

@ -7,13 +7,13 @@ from typing import Optional
from flask import current_app, request from flask import current_app, request
from flask_login import UserMixin from flask_login import UserMixin
from sqlalchemy import Float, text from sqlalchemy import Float, text
from sqlalchemy.dialects.postgresql import UUID
from core.file.tool_file_parser import ToolFileParser from core.file.tool_file_parser import ToolFileParser
from core.file.upload_file_parser import UploadFileParser from core.file.upload_file_parser import UploadFileParser
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import generate_string from libs.helper import generate_string
from . import StringUUID
from .account import Account, Tenant from .account import Account, Tenant
@ -56,15 +56,15 @@ class App(db.Model):
db.Index('app_tenant_id_idx', 'tenant_id') db.Index('app_tenant_id_idx', 'tenant_id')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
mode = db.Column(db.String(255), nullable=False) mode = db.Column(db.String(255), nullable=False)
icon = db.Column(db.String(255)) icon = db.Column(db.String(255))
icon_background = db.Column(db.String(255)) icon_background = db.Column(db.String(255))
app_model_config_id = db.Column(UUID, nullable=True) app_model_config_id = db.Column(StringUUID, nullable=True)
workflow_id = db.Column(UUID, nullable=True) workflow_id = db.Column(StringUUID, nullable=True)
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
enable_site = db.Column(db.Boolean, nullable=False) enable_site = db.Column(db.Boolean, nullable=False)
enable_api = db.Column(db.Boolean, nullable=False) enable_api = db.Column(db.Boolean, nullable=False)
@ -207,8 +207,8 @@ class AppModelConfig(db.Model):
db.Index('app_app_id_idx', 'app_id') db.Index('app_app_id_idx', 'app_id')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
provider = db.Column(db.String(255), nullable=True) provider = db.Column(db.String(255), nullable=True)
model_id = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True)
configs = db.Column(db.JSON, nullable=True) configs = db.Column(db.JSON, nullable=True)
@ -430,8 +430,8 @@ class RecommendedApp(db.Model):
db.Index('recommended_app_is_listed_idx', 'is_listed', 'language') db.Index('recommended_app_is_listed_idx', 'is_listed', 'language')
) )
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
description = db.Column(db.JSON, nullable=False) description = db.Column(db.JSON, nullable=False)
copyright = db.Column(db.String(255), nullable=False) copyright = db.Column(db.String(255), nullable=False)
privacy_policy = db.Column(db.String(255), nullable=False) privacy_policy = db.Column(db.String(255), nullable=False)
@ -458,10 +458,10 @@ class InstalledApp(db.Model):
db.UniqueConstraint('tenant_id', 'app_id', name='unique_tenant_app') db.UniqueConstraint('tenant_id', 'app_id', name='unique_tenant_app')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
app_owner_tenant_id = db.Column(UUID, nullable=False) app_owner_tenant_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False, default=0) position = db.Column(db.Integer, nullable=False, default=0)
is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
last_used_at = db.Column(db.DateTime, nullable=True) last_used_at = db.Column(db.DateTime, nullable=True)
@ -486,9 +486,9 @@ class Conversation(db.Model):
db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id') db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
app_model_config_id = db.Column(UUID, nullable=True) app_model_config_id = db.Column(StringUUID, nullable=True)
model_provider = db.Column(db.String(255), nullable=True) model_provider = db.Column(db.String(255), nullable=True)
override_model_configs = db.Column(db.Text) override_model_configs = db.Column(db.Text)
model_id = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True)
@ -502,10 +502,10 @@ class Conversation(db.Model):
status = db.Column(db.String(255), nullable=False) status = db.Column(db.String(255), nullable=False)
invoke_from = db.Column(db.String(255), nullable=True) invoke_from = db.Column(db.String(255), nullable=True)
from_source = db.Column(db.String(255), nullable=False) from_source = db.Column(db.String(255), nullable=False)
from_end_user_id = db.Column(UUID) from_end_user_id = db.Column(StringUUID)
from_account_id = db.Column(UUID) from_account_id = db.Column(StringUUID)
read_at = db.Column(db.DateTime) read_at = db.Column(db.DateTime)
read_account_id = db.Column(UUID) read_account_id = db.Column(StringUUID)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -626,12 +626,12 @@ class Message(db.Model):
db.Index('message_account_idx', 'app_id', 'from_source', 'from_account_id'), db.Index('message_account_idx', 'app_id', 'from_source', 'from_account_id'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
model_provider = db.Column(db.String(255), nullable=True) model_provider = db.Column(db.String(255), nullable=True)
model_id = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True)
override_model_configs = db.Column(db.Text) override_model_configs = db.Column(db.Text)
conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False) conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=False)
inputs = db.Column(db.JSON) inputs = db.Column(db.JSON)
query = db.Column(db.Text, nullable=False) query = db.Column(db.Text, nullable=False)
message = db.Column(db.JSON, nullable=False) message = db.Column(db.JSON, nullable=False)
@ -650,12 +650,12 @@ class Message(db.Model):
message_metadata = db.Column(db.Text) message_metadata = db.Column(db.Text)
invoke_from = db.Column(db.String(255), nullable=True) invoke_from = db.Column(db.String(255), nullable=True)
from_source = db.Column(db.String(255), nullable=False) from_source = db.Column(db.String(255), nullable=False)
from_end_user_id = db.Column(UUID) from_end_user_id = db.Column(StringUUID)
from_account_id = db.Column(UUID) from_account_id = db.Column(StringUUID)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
workflow_run_id = db.Column(UUID) workflow_run_id = db.Column(StringUUID)
@property @property
def re_sign_file_url_answer(self) -> str: def re_sign_file_url_answer(self) -> str:
@ -846,15 +846,15 @@ class MessageFeedback(db.Model):
db.Index('message_feedback_conversation_idx', 'conversation_id', 'from_source', 'rating') db.Index('message_feedback_conversation_idx', 'conversation_id', 'from_source', 'rating')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
conversation_id = db.Column(UUID, nullable=False) conversation_id = db.Column(StringUUID, nullable=False)
message_id = db.Column(UUID, nullable=False) message_id = db.Column(StringUUID, nullable=False)
rating = db.Column(db.String(255), nullable=False) rating = db.Column(db.String(255), nullable=False)
content = db.Column(db.Text) content = db.Column(db.Text)
from_source = db.Column(db.String(255), nullable=False) from_source = db.Column(db.String(255), nullable=False)
from_end_user_id = db.Column(UUID) from_end_user_id = db.Column(StringUUID)
from_account_id = db.Column(UUID) from_account_id = db.Column(StringUUID)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -872,15 +872,15 @@ class MessageFile(db.Model):
db.Index('message_file_created_by_idx', 'created_by') db.Index('message_file_created_by_idx', 'created_by')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(UUID, nullable=False) message_id = db.Column(StringUUID, nullable=False)
type = db.Column(db.String(255), nullable=False) type = db.Column(db.String(255), nullable=False)
transfer_method = db.Column(db.String(255), nullable=False) transfer_method = db.Column(db.String(255), nullable=False)
url = db.Column(db.Text, nullable=True) url = db.Column(db.Text, nullable=True)
belongs_to = db.Column(db.String(255), nullable=True) belongs_to = db.Column(db.String(255), nullable=True)
upload_file_id = db.Column(UUID, nullable=True) upload_file_id = db.Column(StringUUID, nullable=True)
created_by_role = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -893,14 +893,14 @@ class MessageAnnotation(db.Model):
db.Index('message_annotation_message_idx', 'message_id') db.Index('message_annotation_message_idx', 'message_id')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=True) conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=True)
message_id = db.Column(UUID, nullable=True) message_id = db.Column(StringUUID, nullable=True)
question = db.Column(db.Text, nullable=True) question = db.Column(db.Text, nullable=True)
content = db.Column(db.Text, nullable=False) content = db.Column(db.Text, nullable=False)
hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0')) hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
account_id = db.Column(UUID, nullable=False) account_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -925,15 +925,15 @@ class AppAnnotationHitHistory(db.Model):
db.Index('app_annotation_hit_histories_message_idx', 'message_id'), db.Index('app_annotation_hit_histories_message_idx', 'message_id'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
annotation_id = db.Column(UUID, nullable=False) annotation_id = db.Column(StringUUID, nullable=False)
source = db.Column(db.Text, nullable=False) source = db.Column(db.Text, nullable=False)
question = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False)
account_id = db.Column(UUID, nullable=False) account_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
score = db.Column(Float, nullable=False, server_default=db.text('0')) score = db.Column(Float, nullable=False, server_default=db.text('0'))
message_id = db.Column(UUID, nullable=False) message_id = db.Column(StringUUID, nullable=False)
annotation_question = db.Column(db.Text, nullable=False) annotation_question = db.Column(db.Text, nullable=False)
annotation_content = db.Column(db.Text, nullable=False) annotation_content = db.Column(db.Text, nullable=False)
@ -957,13 +957,13 @@ class AppAnnotationSetting(db.Model):
db.Index('app_annotation_settings_app_idx', 'app_id') db.Index('app_annotation_settings_app_idx', 'app_id')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
score_threshold = db.Column(Float, nullable=False, server_default=db.text('0')) score_threshold = db.Column(Float, nullable=False, server_default=db.text('0'))
collection_binding_id = db.Column(UUID, nullable=False) collection_binding_id = db.Column(StringUUID, nullable=False)
created_user_id = db.Column(UUID, nullable=False) created_user_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_user_id = db.Column(UUID, nullable=False) updated_user_id = db.Column(StringUUID, nullable=False)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property @property
@ -995,9 +995,9 @@ class OperationLog(db.Model):
db.Index('operation_log_account_action_idx', 'tenant_id', 'account_id', 'action') db.Index('operation_log_account_action_idx', 'tenant_id', 'account_id', 'action')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
account_id = db.Column(UUID, nullable=False) account_id = db.Column(StringUUID, nullable=False)
action = db.Column(db.String(255), nullable=False) action = db.Column(db.String(255), nullable=False)
content = db.Column(db.JSON) content = db.Column(db.JSON)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -1013,9 +1013,9 @@ class EndUser(UserMixin, db.Model):
db.Index('end_user_tenant_session_id_idx', 'tenant_id', 'session_id', 'type'), db.Index('end_user_tenant_session_id_idx', 'tenant_id', 'session_id', 'type'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
app_id = db.Column(UUID, nullable=True) app_id = db.Column(StringUUID, nullable=True)
type = db.Column(db.String(255), nullable=False) type = db.Column(db.String(255), nullable=False)
external_user_id = db.Column(db.String(255), nullable=True) external_user_id = db.Column(db.String(255), nullable=True)
name = db.Column(db.String(255)) name = db.Column(db.String(255))
@ -1033,8 +1033,8 @@ class Site(db.Model):
db.Index('site_code_idx', 'code', 'status') db.Index('site_code_idx', 'code', 'status')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
title = db.Column(db.String(255), nullable=False) title = db.Column(db.String(255), nullable=False)
icon = db.Column(db.String(255)) icon = db.Column(db.String(255))
icon_background = db.Column(db.String(255)) icon_background = db.Column(db.String(255))
@ -1074,9 +1074,9 @@ class ApiToken(db.Model):
db.Index('api_token_tenant_idx', 'tenant_id', 'type') db.Index('api_token_tenant_idx', 'tenant_id', 'type')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=True) app_id = db.Column(StringUUID, nullable=True)
tenant_id = db.Column(UUID, nullable=True) tenant_id = db.Column(StringUUID, nullable=True)
type = db.Column(db.String(16), nullable=False) type = db.Column(db.String(16), nullable=False)
token = db.Column(db.String(255), nullable=False) token = db.Column(db.String(255), nullable=False)
last_used_at = db.Column(db.DateTime, nullable=True) last_used_at = db.Column(db.DateTime, nullable=True)
@ -1099,8 +1099,8 @@ class UploadFile(db.Model):
db.Index('upload_file_tenant_idx', 'tenant_id') db.Index('upload_file_tenant_idx', 'tenant_id')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
storage_type = db.Column(db.String(255), nullable=False) storage_type = db.Column(db.String(255), nullable=False)
key = db.Column(db.String(255), nullable=False) key = db.Column(db.String(255), nullable=False)
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
@ -1108,10 +1108,10 @@ class UploadFile(db.Model):
extension = db.Column(db.String(255), nullable=False) extension = db.Column(db.String(255), nullable=False)
mime_type = db.Column(db.String(255), nullable=True) mime_type = db.Column(db.String(255), nullable=True)
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying")) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
used = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
used_by = db.Column(UUID, nullable=True) used_by = db.Column(StringUUID, nullable=True)
used_at = db.Column(db.DateTime, nullable=True) used_at = db.Column(db.DateTime, nullable=True)
hash = db.Column(db.String(255), nullable=True) hash = db.Column(db.String(255), nullable=True)
@ -1123,9 +1123,9 @@ class ApiRequest(db.Model):
db.Index('api_request_token_idx', 'tenant_id', 'api_token_id') db.Index('api_request_token_idx', 'tenant_id', 'api_token_id')
) )
id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
api_token_id = db.Column(UUID, nullable=False) api_token_id = db.Column(StringUUID, nullable=False)
path = db.Column(db.String(255), nullable=False) path = db.Column(db.String(255), nullable=False)
request = db.Column(db.Text, nullable=True) request = db.Column(db.Text, nullable=True)
response = db.Column(db.Text, nullable=True) response = db.Column(db.Text, nullable=True)
@ -1140,8 +1140,8 @@ class MessageChain(db.Model):
db.Index('message_chain_message_id_idx', 'message_id') db.Index('message_chain_message_id_idx', 'message_id')
) )
id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(UUID, nullable=False) message_id = db.Column(StringUUID, nullable=False)
type = db.Column(db.String(255), nullable=False) type = db.Column(db.String(255), nullable=False)
input = db.Column(db.Text, nullable=True) input = db.Column(db.Text, nullable=True)
output = db.Column(db.Text, nullable=True) output = db.Column(db.Text, nullable=True)
@ -1156,9 +1156,9 @@ class MessageAgentThought(db.Model):
db.Index('message_agent_thought_message_chain_id_idx', 'message_chain_id'), db.Index('message_agent_thought_message_chain_id_idx', 'message_chain_id'),
) )
id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(UUID, nullable=False) message_id = db.Column(StringUUID, nullable=False)
message_chain_id = db.Column(UUID, nullable=True) message_chain_id = db.Column(StringUUID, nullable=True)
position = db.Column(db.Integer, nullable=False) position = db.Column(db.Integer, nullable=False)
thought = db.Column(db.Text, nullable=True) thought = db.Column(db.Text, nullable=True)
tool = db.Column(db.Text, nullable=True) tool = db.Column(db.Text, nullable=True)
@ -1166,7 +1166,7 @@ class MessageAgentThought(db.Model):
tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
tool_input = db.Column(db.Text, nullable=True) tool_input = db.Column(db.Text, nullable=True)
observation = db.Column(db.Text, nullable=True) observation = db.Column(db.Text, nullable=True)
# plugin_id = db.Column(UUID, nullable=True) ## for future design # plugin_id = db.Column(StringUUID, nullable=True) ## for future design
tool_process_data = db.Column(db.Text, nullable=True) tool_process_data = db.Column(db.Text, nullable=True)
message = db.Column(db.Text, nullable=True) message = db.Column(db.Text, nullable=True)
message_token = db.Column(db.Integer, nullable=True) message_token = db.Column(db.Integer, nullable=True)
@ -1182,7 +1182,7 @@ class MessageAgentThought(db.Model):
currency = db.Column(db.String, nullable=True) currency = db.Column(db.String, nullable=True)
latency = db.Column(db.Float, nullable=True) latency = db.Column(db.Float, nullable=True)
created_by_role = db.Column(db.String, nullable=False) created_by_role = db.Column(db.String, nullable=False)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
@property @property
@ -1273,15 +1273,15 @@ class DatasetRetrieverResource(db.Model):
db.Index('dataset_retriever_resource_message_id_idx', 'message_id'), db.Index('dataset_retriever_resource_message_id_idx', 'message_id'),
) )
id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(UUID, nullable=False) message_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False) position = db.Column(db.Integer, nullable=False)
dataset_id = db.Column(UUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False)
dataset_name = db.Column(db.Text, nullable=False) dataset_name = db.Column(db.Text, nullable=False)
document_id = db.Column(UUID, nullable=False) document_id = db.Column(StringUUID, nullable=False)
document_name = db.Column(db.Text, nullable=False) document_name = db.Column(db.Text, nullable=False)
data_source_type = db.Column(db.Text, nullable=False) data_source_type = db.Column(db.Text, nullable=False)
segment_id = db.Column(UUID, nullable=False) segment_id = db.Column(StringUUID, nullable=False)
score = db.Column(db.Float, nullable=True) score = db.Column(db.Float, nullable=True)
content = db.Column(db.Text, nullable=False) content = db.Column(db.Text, nullable=False)
hit_count = db.Column(db.Integer, nullable=True) hit_count = db.Column(db.Integer, nullable=True)
@ -1289,7 +1289,7 @@ class DatasetRetrieverResource(db.Model):
segment_position = db.Column(db.Integer, nullable=True) segment_position = db.Column(db.Integer, nullable=True)
index_node_hash = db.Column(db.Text, nullable=True) index_node_hash = db.Column(db.Text, nullable=True)
retriever_from = db.Column(db.Text, nullable=False) retriever_from = db.Column(db.Text, nullable=False)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
@ -1303,11 +1303,11 @@ class Tag(db.Model):
TAG_TYPE_LIST = ['knowledge', 'app'] TAG_TYPE_LIST = ['knowledge', 'app']
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=True) tenant_id = db.Column(StringUUID, nullable=True)
type = db.Column(db.String(16), nullable=False) type = db.Column(db.String(16), nullable=False)
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -1319,9 +1319,9 @@ class TagBinding(db.Model):
db.Index('tag_bind_tag_id_idx', 'tag_id'), db.Index('tag_bind_tag_id_idx', 'tag_id'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=True) tenant_id = db.Column(StringUUID, nullable=True)
tag_id = db.Column(UUID, nullable=True) tag_id = db.Column(StringUUID, nullable=True)
target_id = db.Column(UUID, nullable=True) target_id = db.Column(StringUUID, nullable=True)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

View File

@ -1,8 +1,7 @@
from enum import Enum from enum import Enum
from sqlalchemy.dialects.postgresql import UUID
from extensions.ext_database import db from extensions.ext_database import db
from models import StringUUID
class ProviderType(Enum): class ProviderType(Enum):
@ -46,8 +45,8 @@ class Provider(db.Model):
db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota') db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False) provider_name = db.Column(db.String(40), nullable=False)
provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
encrypted_config = db.Column(db.Text, nullable=True) encrypted_config = db.Column(db.Text, nullable=True)
@ -93,8 +92,8 @@ class ProviderModel(db.Model):
db.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name') db.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False) provider_name = db.Column(db.String(40), nullable=False)
model_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False) model_type = db.Column(db.String(40), nullable=False)
@ -111,8 +110,8 @@ class TenantDefaultModel(db.Model):
db.Index('tenant_default_model_tenant_id_provider_type_idx', 'tenant_id', 'provider_name', 'model_type'), db.Index('tenant_default_model_tenant_id_provider_type_idx', 'tenant_id', 'provider_name', 'model_type'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False) provider_name = db.Column(db.String(40), nullable=False)
model_name = db.Column(db.String(40), nullable=False) model_name = db.Column(db.String(40), nullable=False)
model_type = db.Column(db.String(40), nullable=False) model_type = db.Column(db.String(40), nullable=False)
@ -127,8 +126,8 @@ class TenantPreferredModelProvider(db.Model):
db.Index('tenant_preferred_model_provider_tenant_provider_idx', 'tenant_id', 'provider_name'), db.Index('tenant_preferred_model_provider_tenant_provider_idx', 'tenant_id', 'provider_name'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False) provider_name = db.Column(db.String(40), nullable=False)
preferred_provider_type = db.Column(db.String(40), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -142,10 +141,10 @@ class ProviderOrder(db.Model):
db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'), db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False) provider_name = db.Column(db.String(40), nullable=False)
account_id = db.Column(UUID, nullable=False) account_id = db.Column(StringUUID, nullable=False)
payment_product_id = db.Column(db.String(191), nullable=False) payment_product_id = db.Column(db.String(191), nullable=False)
payment_id = db.Column(db.String(191)) payment_id = db.Column(db.String(191))
transaction_id = db.Column(db.String(191)) transaction_id = db.Column(db.String(191))

View File

@ -1,6 +1,7 @@
from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.dialects.postgresql import JSONB
from extensions.ext_database import db from extensions.ext_database import db
from models import StringUUID
class DataSourceBinding(db.Model): class DataSourceBinding(db.Model):
@ -11,8 +12,8 @@ class DataSourceBinding(db.Model):
db.Index('source_info_idx', "source_info", postgresql_using='gin') db.Index('source_info_idx', "source_info", postgresql_using='gin')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
access_token = db.Column(db.String(255), nullable=False) access_token = db.Column(db.String(255), nullable=False)
provider = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False)
source_info = db.Column(JSONB, nullable=False) source_info = db.Column(JSONB, nullable=False)

View File

@ -1,9 +1,8 @@
import json import json
from enum import Enum from enum import Enum
from sqlalchemy.dialects.postgresql import UUID
from extensions.ext_database import db from extensions.ext_database import db
from models import StringUUID
class ToolProviderName(Enum): class ToolProviderName(Enum):
@ -24,8 +23,8 @@ class ToolProvider(db.Model):
db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
tool_name = db.Column(db.String(40), nullable=False) tool_name = db.Column(db.String(40), nullable=False)
encrypted_credentials = db.Column(db.Text, nullable=True) encrypted_credentials = db.Column(db.Text, nullable=True)
is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))

View File

@ -1,12 +1,12 @@
import json import json
from sqlalchemy import ForeignKey from sqlalchemy import ForeignKey
from sqlalchemy.dialects.postgresql import UUID
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiBasedToolBundle from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType from core.tools.entities.tool_entities import ApiProviderSchemaType
from extensions.ext_database import db from extensions.ext_database import db
from models import StringUUID
from models.model import Account, App, Tenant from models.model import Account, App, Tenant
@ -22,11 +22,11 @@ class BuiltinToolProvider(db.Model):
) )
# id of the tool provider # id of the tool provider
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# id of the tenant # id of the tenant
tenant_id = db.Column(UUID, nullable=True) tenant_id = db.Column(StringUUID, nullable=True)
# who created this tool provider # who created this tool provider
user_id = db.Column(UUID, nullable=False) user_id = db.Column(StringUUID, nullable=False)
# name of the tool provider # name of the tool provider
provider = db.Column(db.String(40), nullable=False) provider = db.Column(db.String(40), nullable=False)
# credential of the tool provider # credential of the tool provider
@ -49,11 +49,11 @@ class PublishedAppTool(db.Model):
) )
# id of the tool provider # id of the tool provider
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# id of the app # id of the app
app_id = db.Column(UUID, ForeignKey('apps.id'), nullable=False) app_id = db.Column(StringUUID, ForeignKey('apps.id'), nullable=False)
# who published this tool # who published this tool
user_id = db.Column(UUID, nullable=False) user_id = db.Column(StringUUID, nullable=False)
# description of the tool, stored in i18n format, for human # description of the tool, stored in i18n format, for human
description = db.Column(db.Text, nullable=False) description = db.Column(db.Text, nullable=False)
# llm_description of the tool, for LLM # llm_description of the tool, for LLM
@ -87,7 +87,7 @@ class ApiToolProvider(db.Model):
db.UniqueConstraint('name', 'tenant_id', name='unique_api_tool_provider') db.UniqueConstraint('name', 'tenant_id', name='unique_api_tool_provider')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# name of the api provider # name of the api provider
name = db.Column(db.String(40), nullable=False) name = db.Column(db.String(40), nullable=False)
# icon # icon
@ -96,9 +96,9 @@ class ApiToolProvider(db.Model):
schema = db.Column(db.Text, nullable=False) schema = db.Column(db.Text, nullable=False)
schema_type_str = db.Column(db.String(40), nullable=False) schema_type_str = db.Column(db.String(40), nullable=False)
# who created this tool # who created this tool
user_id = db.Column(UUID, nullable=False) user_id = db.Column(StringUUID, nullable=False)
# tenant id # tenant id
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
# description of the provider # description of the provider
description = db.Column(db.Text, nullable=False) description = db.Column(db.Text, nullable=False)
# json format tools # json format tools
@ -140,11 +140,11 @@ class ToolModelInvoke(db.Model):
db.PrimaryKeyConstraint('id', name='tool_model_invoke_pkey'), db.PrimaryKeyConstraint('id', name='tool_model_invoke_pkey'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# who invoke this tool # who invoke this tool
user_id = db.Column(UUID, nullable=False) user_id = db.Column(StringUUID, nullable=False)
# tenant id # tenant id
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
# provider # provider
provider = db.Column(db.String(40), nullable=False) provider = db.Column(db.String(40), nullable=False)
# type # type
@ -180,13 +180,13 @@ class ToolConversationVariables(db.Model):
db.Index('conversation_id_idx', 'conversation_id'), db.Index('conversation_id_idx', 'conversation_id'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# conversation user id # conversation user id
user_id = db.Column(UUID, nullable=False) user_id = db.Column(StringUUID, nullable=False)
# tenant id # tenant id
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
# conversation id # conversation id
conversation_id = db.Column(UUID, nullable=False) conversation_id = db.Column(StringUUID, nullable=False)
# variables pool # variables pool
variables_str = db.Column(db.Text, nullable=False) variables_str = db.Column(db.Text, nullable=False)
@ -208,13 +208,13 @@ class ToolFile(db.Model):
db.Index('tool_file_conversation_id_idx', 'conversation_id'), db.Index('tool_file_conversation_id_idx', 'conversation_id'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# conversation user id # conversation user id
user_id = db.Column(UUID, nullable=False) user_id = db.Column(StringUUID, nullable=False)
# tenant id # tenant id
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
# conversation id # conversation id
conversation_id = db.Column(UUID, nullable=True) conversation_id = db.Column(StringUUID, nullable=True)
# file key # file key
file_key = db.Column(db.String(255), nullable=False) file_key = db.Column(db.String(255), nullable=False)
# mime type # mime type

View File

@ -1,6 +1,6 @@
from sqlalchemy.dialects.postgresql import UUID
from extensions.ext_database import db from extensions.ext_database import db
from models import StringUUID
from models.model import Message from models.model import Message
@ -11,11 +11,11 @@ class SavedMessage(db.Model):
db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by_role', 'created_by'), db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by_role', 'created_by'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
message_id = db.Column(UUID, nullable=False) message_id = db.Column(StringUUID, nullable=False)
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property @property
@ -30,9 +30,9 @@ class PinnedConversation(db.Model):
db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by_role', 'created_by'), db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by_role', 'created_by'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
conversation_id = db.Column(UUID, nullable=False) conversation_id = db.Column(StringUUID, nullable=False)
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

View File

@ -2,10 +2,9 @@ import json
from enum import Enum from enum import Enum
from typing import Optional, Union from typing import Optional, Union
from sqlalchemy.dialects.postgresql import UUID
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from extensions.ext_database import db from extensions.ext_database import db
from models import StringUUID
from models.account import Account from models.account import Account
@ -102,16 +101,16 @@ class Workflow(db.Model):
db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'), db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
type = db.Column(db.String(255), nullable=False) type = db.Column(db.String(255), nullable=False)
version = db.Column(db.String(255), nullable=False) version = db.Column(db.String(255), nullable=False)
graph = db.Column(db.Text) graph = db.Column(db.Text)
features = db.Column(db.Text) features = db.Column(db.Text)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by = db.Column(UUID) updated_by = db.Column(StringUUID)
updated_at = db.Column(db.DateTime) updated_at = db.Column(db.DateTime)
@property @property
@ -245,11 +244,11 @@ class WorkflowRun(db.Model):
db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'), db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
sequence_number = db.Column(db.Integer, nullable=False) sequence_number = db.Column(db.Integer, nullable=False)
workflow_id = db.Column(UUID, nullable=False) workflow_id = db.Column(StringUUID, nullable=False)
type = db.Column(db.String(255), nullable=False) type = db.Column(db.String(255), nullable=False)
triggered_from = db.Column(db.String(255), nullable=False) triggered_from = db.Column(db.String(255), nullable=False)
version = db.Column(db.String(255), nullable=False) version = db.Column(db.String(255), nullable=False)
@ -262,7 +261,7 @@ class WorkflowRun(db.Model):
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
total_steps = db.Column(db.Integer, server_default=db.text('0')) total_steps = db.Column(db.Integer, server_default=db.text('0'))
created_by_role = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
finished_at = db.Column(db.DateTime) finished_at = db.Column(db.DateTime)
@ -404,12 +403,12 @@ class WorkflowNodeExecution(db.Model):
'triggered_from', 'node_id'), 'triggered_from', 'node_id'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
workflow_id = db.Column(UUID, nullable=False) workflow_id = db.Column(StringUUID, nullable=False)
triggered_from = db.Column(db.String(255), nullable=False) triggered_from = db.Column(db.String(255), nullable=False)
workflow_run_id = db.Column(UUID) workflow_run_id = db.Column(StringUUID)
index = db.Column(db.Integer, nullable=False) index = db.Column(db.Integer, nullable=False)
predecessor_node_id = db.Column(db.String(255)) predecessor_node_id = db.Column(db.String(255))
node_id = db.Column(db.String(255), nullable=False) node_id = db.Column(db.String(255), nullable=False)
@ -424,7 +423,7 @@ class WorkflowNodeExecution(db.Model):
execution_metadata = db.Column(db.Text) execution_metadata = db.Column(db.Text)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
created_by_role = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
finished_at = db.Column(db.DateTime) finished_at = db.Column(db.DateTime)
@property @property
@ -529,14 +528,14 @@ class WorkflowAppLog(db.Model):
db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'), db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'),
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
app_id = db.Column(UUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
workflow_id = db.Column(UUID, nullable=False) workflow_id = db.Column(StringUUID, nullable=False)
workflow_run_id = db.Column(UUID, nullable=False) workflow_run_id = db.Column(StringUUID, nullable=False)
created_from = db.Column(db.String(255), nullable=False) created_from = db.Column(db.String(255), nullable=False)
created_by_role = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property @property

View File

@ -1,7 +1,7 @@
beautifulsoup4==4.12.2 beautifulsoup4==4.12.2
flask~=3.0.1 flask~=3.0.1
Flask-SQLAlchemy~=3.0.5 Flask-SQLAlchemy~=3.0.5
SQLAlchemy~=1.4.28 SQLAlchemy~=2.0.29
Flask-Compress~=1.14 Flask-Compress~=1.14
flask-login~=0.6.3 flask-login~=0.6.3
flask-migrate~=4.0.5 flask-migrate~=4.0.5

View File

@ -0,0 +1,37 @@
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_text,
setup_mock_redis,
)
class TestPgvectoRSVector(AbstractVectorTest):
def __init__(self):
super().__init__()
self.vector = PGVectoRS(
collection_name=self.collection_name.lower(),
config=PgvectoRSConfig(
host='localhost',
port=5431,
user='postgres',
password='difyai123456',
database='dify',
),
dim=128
)
def search_by_full_text(self):
# pgvecto rs only support english text search, So its not open for now
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
def delete_by_document_id(self):
self.vector.delete_by_document_id(document_id=self.example_doc_id)
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
assert len(ids) == 1
def test_pgvecot_rs(setup_mock_redis):
TestPgvectoRSVector().run_all_tests()

View File

@ -45,7 +45,7 @@ class AbstractVectorTest:
def __init__(self): def __init__(self):
self.vector = None self.vector = None
self.dataset_id = str(uuid.uuid4()) self.dataset_id = str(uuid.uuid4())
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + '_test'
self.example_doc_id = str(uuid.uuid4()) self.example_doc_id = str(uuid.uuid4())
self.example_embedding = [1.001 * i for i in range(128)] self.example_embedding = [1.001 * i for i in range(128)]

View File

@ -0,0 +1,24 @@
version: '3'
services:
# The pgvecto—rs database.
pgvecto-rs:
image: tensorchord/pgvecto-rs:pg16-v0.2.0
restart: always
environment:
PGUSER: postgres
# The password for the default postgres user.
POSTGRES_PASSWORD: difyai123456
# The name of the default postgres database.
POSTGRES_DB: dify
# postgres data directory
PGDATA: /var/lib/postgresql/data/pgdata
volumes:
- ./volumes/pgvectors/data:/var/lib/postgresql/data
# uncomment to expose db(postgresql) port to host
ports:
- "5431:5432"
healthcheck:
test: [ "CMD", "pg_isready" ]
interval: 1s
timeout: 3s
retries: 30