compatible with original provider name

This commit is contained in:
takatost 2024-11-01 00:00:53 -07:00
parent d0c53fabca
commit b75dce5d0a
5 changed files with 215 additions and 1 deletions

View File

@ -25,6 +25,7 @@ from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel from models.provider import Provider, ProviderModel
from services.account_service import RegisterService, TenantService from services.account_service import RegisterService, TenantService
from services.plugin.data_migration import PluginDataMigration
@click.command("reset-password", help="Reset the account password.") @click.command("reset-password", help="Reset the account password.")
@ -642,6 +643,18 @@ where sites.id is null limit 1000"""
click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green")) click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
@click.command("migrate-data-for-plugin", help="Migrate data for plugin.")
def migrate_data_for_plugin():
"""
Migrate data for plugin.
"""
click.echo(click.style("Starting migrate data for plugin.", fg="white"))
PluginDataMigration.migrate()
click.echo(click.style("Migrate data for plugin completed.", fg="green"))
def register_commands(app): def register_commands(app):
app.cli.add_command(reset_password) app.cli.add_command(reset_password)
app.cli.add_command(reset_email) app.cli.add_command(reset_email)
@ -652,3 +665,4 @@ def register_commands(app):
app.cli.add_command(create_tenant) app.cli.add_command(create_tenant)
app.cli.add_command(upgrade_db) app.cli.add_command(upgrade_db)
app.cli.add_command(fix_app_site_missing) app.cli.add_command(fix_app_site_missing)
app.cli.add_command(migrate_data_for_plugin)

View File

@ -1,4 +1,5 @@
from core.app.app_config.entities import ModelConfigEntity from core.app.app_config.entities import ModelConfigEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
@ -53,7 +54,15 @@ class ModelConfigManager:
model_provider_factory = ModelProviderFactory(tenant_id) model_provider_factory = ModelProviderFactory(tenant_id)
provider_entities = model_provider_factory.get_providers() provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities] model_provider_names = [provider.provider for provider in provider_entities]
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names: if "provider" not in config["model"]:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
if "/" not in config["model"]["provider"]:
config["model"]["provider"] = (
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
)
if config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
# model.name # model.name

View File

@ -9,6 +9,7 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from core.entities import DEFAULT_PLUGIN_ID
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import ( from core.entities.provider_entities import (
CustomConfiguration, CustomConfiguration,
@ -1047,6 +1048,9 @@ class ProviderConfigurations(BaseModel):
return list(self.values()) return list(self.values())
def __getitem__(self, key): def __getitem__(self, key):
if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
return self.configurations[key] return self.configurations[key]
def __setitem__(self, key, value): def __setitem__(self, key, value):
@ -1059,6 +1063,9 @@ class ProviderConfigurations(BaseModel):
return iter(self.configurations.values()) return iter(self.configurations.values())
def get(self, key, default=None): def get(self, key, default=None):
if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
return self.configurations.get(key, default) return self.configurations.get(key, default)

View File

View File

@ -0,0 +1,184 @@
import json
import logging
import click
from core.entities import DEFAULT_PLUGIN_ID
from extensions.ext_database import db
logger = logging.getLogger(__name__)
class PluginDataMigration:
@classmethod
def migrate(cls) -> None:
cls.migrate_db_records("providers", "provider_name") # large table
cls.migrate_db_records("provider_models", "provider_name")
cls.migrate_db_records("provider_orders", "provider_name")
cls.migrate_db_records("tenant_default_models", "provider_name")
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
cls.migrate_db_records("provider_model_settings", "provider_name")
cls.migrate_db_records("load_balancing_model_configs", "provider_name")
cls.migrate_datasets()
cls.migrate_db_records("embeddings", "provider_name") # large table
cls.migrate_db_records("dataset_collection_bindings", "provider_name")
@classmethod
def migrate_datasets(cls) -> None:
table_name = "datasets"
provider_column_name = "embedding_model_provider"
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
processed_count = 0
failed_ids = []
while True:
sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
limit 1000"""
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql))
current_iter_count = 0
for i in rs:
record_id = str(i.id)
provider_name = str(i.provider_name)
retrieval_model = i.retrieval_model
print(type(retrieval_model))
if record_id in failed_ids:
continue
retrieval_model_changed = False
if retrieval_model:
if (
"reranking_model" in retrieval_model
and "reranking_provider_name" in retrieval_model["reranking_model"]
and retrieval_model["reranking_model"]["reranking_provider_name"]
and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
):
click.echo(
click.style(
f"[{processed_count}] Migrating {table_name} {record_id} "
f"(reranking_provider_name: "
f"{retrieval_model['reranking_model']['reranking_provider_name']})",
fg="white",
)
)
retrieval_model["reranking_model"]["reranking_provider_name"] = (
f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}"
)
retrieval_model_changed = True
click.echo(
click.style(
f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
fg="white",
)
)
try:
# update provider name append with "langgenius/{provider_name}/{provider_name}"
params = {"record_id": record_id}
update_retrieval_model_sql = ""
if retrieval_model and retrieval_model_changed:
update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
params["retrieval_model"] = json.dumps(retrieval_model)
sql = f"""update {table_name}
set {provider_column_name} =
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
{update_retrieval_model_sql}
where id = :record_id"""
conn.execute(db.text(sql), params)
click.echo(
click.style(
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
fg="green",
)
)
except Exception:
failed_ids.append(record_id)
click.echo(
click.style(
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
fg="red",
)
)
logger.exception(
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
)
continue
current_iter_count += 1
processed_count += 1
if not current_iter_count:
break
click.echo(
click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
)
@classmethod
def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
processed_count = 0
failed_ids = []
while True:
sql = f"""select id, {provider_column_name} as provider_name from {table_name}
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
limit 1000"""
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql))
current_iter_count = 0
for i in rs:
current_iter_count += 1
processed_count += 1
record_id = str(i.id)
provider_name = str(i.provider_name)
if record_id in failed_ids:
continue
click.echo(
click.style(
f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
fg="white",
)
)
try:
# update provider name append with "langgenius/{provider_name}/{provider_name}"
sql = f"""update {table_name}
set {provider_column_name} =
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
where id = :record_id"""
conn.execute(db.text(sql), {"record_id": record_id})
click.echo(
click.style(
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
fg="green",
)
)
except Exception:
failed_ids.append(record_id)
click.echo(
click.style(
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
fg="red",
)
)
logger.exception(
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
)
continue
if not current_iter_count:
break
click.echo(
click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
)