mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
compatible with original provider name
This commit is contained in:
parent
d0c53fabca
commit
b75dce5d0a
|
@ -25,6 +25,7 @@ from models.dataset import Document as DatasetDocument
|
||||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||||
from models.provider import Provider, ProviderModel
|
from models.provider import Provider, ProviderModel
|
||||||
from services.account_service import RegisterService, TenantService
|
from services.account_service import RegisterService, TenantService
|
||||||
|
from services.plugin.data_migration import PluginDataMigration
|
||||||
|
|
||||||
|
|
||||||
@click.command("reset-password", help="Reset the account password.")
|
@click.command("reset-password", help="Reset the account password.")
|
||||||
|
@ -642,6 +643,18 @@ where sites.id is null limit 1000"""
|
||||||
click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
|
click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
|
||||||
|
|
||||||
|
|
||||||
|
@click.command("migrate-data-for-plugin", help="Migrate data for plugin.")
|
||||||
|
def migrate_data_for_plugin():
|
||||||
|
"""
|
||||||
|
Migrate data for plugin.
|
||||||
|
"""
|
||||||
|
click.echo(click.style("Starting migrate data for plugin.", fg="white"))
|
||||||
|
|
||||||
|
PluginDataMigration.migrate()
|
||||||
|
|
||||||
|
click.echo(click.style("Migrate data for plugin completed.", fg="green"))
|
||||||
|
|
||||||
|
|
||||||
def register_commands(app):
|
def register_commands(app):
|
||||||
app.cli.add_command(reset_password)
|
app.cli.add_command(reset_password)
|
||||||
app.cli.add_command(reset_email)
|
app.cli.add_command(reset_email)
|
||||||
|
@ -652,3 +665,4 @@ def register_commands(app):
|
||||||
app.cli.add_command(create_tenant)
|
app.cli.add_command(create_tenant)
|
||||||
app.cli.add_command(upgrade_db)
|
app.cli.add_command(upgrade_db)
|
||||||
app.cli.add_command(fix_app_site_missing)
|
app.cli.add_command(fix_app_site_missing)
|
||||||
|
app.cli.add_command(migrate_data_for_plugin)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from core.app.app_config.entities import ModelConfigEntity
|
from core.app.app_config.entities import ModelConfigEntity
|
||||||
|
from core.entities import DEFAULT_PLUGIN_ID
|
||||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
from core.provider_manager import ProviderManager
|
from core.provider_manager import ProviderManager
|
||||||
|
@ -53,7 +54,15 @@ class ModelConfigManager:
|
||||||
model_provider_factory = ModelProviderFactory(tenant_id)
|
model_provider_factory = ModelProviderFactory(tenant_id)
|
||||||
provider_entities = model_provider_factory.get_providers()
|
provider_entities = model_provider_factory.get_providers()
|
||||||
model_provider_names = [provider.provider for provider in provider_entities]
|
model_provider_names = [provider.provider for provider in provider_entities]
|
||||||
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
|
if "provider" not in config["model"]:
|
||||||
|
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||||
|
|
||||||
|
if "/" not in config["model"]["provider"]:
|
||||||
|
config["model"]["provider"] = (
|
||||||
|
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if config["model"]["provider"] not in model_provider_names:
|
||||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||||
|
|
||||||
# model.name
|
# model.name
|
||||||
|
|
|
@ -9,6 +9,7 @@ from typing import Optional
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
|
from core.entities import DEFAULT_PLUGIN_ID
|
||||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||||
from core.entities.provider_entities import (
|
from core.entities.provider_entities import (
|
||||||
CustomConfiguration,
|
CustomConfiguration,
|
||||||
|
@ -1047,6 +1048,9 @@ class ProviderConfigurations(BaseModel):
|
||||||
return list(self.values())
|
return list(self.values())
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
|
if "/" not in key:
|
||||||
|
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||||
|
|
||||||
return self.configurations[key]
|
return self.configurations[key]
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
|
@ -1059,6 +1063,9 @@ class ProviderConfigurations(BaseModel):
|
||||||
return iter(self.configurations.values())
|
return iter(self.configurations.values())
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
|
if "/" not in key:
|
||||||
|
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||||
|
|
||||||
return self.configurations.get(key, default)
|
return self.configurations.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
|
0
api/services/plugin/__init__.py
Normal file
0
api/services/plugin/__init__.py
Normal file
184
api/services/plugin/data_migration.py
Normal file
184
api/services/plugin/data_migration.py
Normal file
|
@ -0,0 +1,184 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
from core.entities import DEFAULT_PLUGIN_ID
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginDataMigration:
|
||||||
|
@classmethod
|
||||||
|
def migrate(cls) -> None:
|
||||||
|
cls.migrate_db_records("providers", "provider_name") # large table
|
||||||
|
cls.migrate_db_records("provider_models", "provider_name")
|
||||||
|
cls.migrate_db_records("provider_orders", "provider_name")
|
||||||
|
cls.migrate_db_records("tenant_default_models", "provider_name")
|
||||||
|
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
|
||||||
|
cls.migrate_db_records("provider_model_settings", "provider_name")
|
||||||
|
cls.migrate_db_records("load_balancing_model_configs", "provider_name")
|
||||||
|
cls.migrate_datasets()
|
||||||
|
cls.migrate_db_records("embeddings", "provider_name") # large table
|
||||||
|
cls.migrate_db_records("dataset_collection_bindings", "provider_name")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def migrate_datasets(cls) -> None:
|
||||||
|
table_name = "datasets"
|
||||||
|
provider_column_name = "embedding_model_provider"
|
||||||
|
|
||||||
|
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||||
|
|
||||||
|
processed_count = 0
|
||||||
|
failed_ids = []
|
||||||
|
while True:
|
||||||
|
sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
|
||||||
|
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
|
||||||
|
limit 1000"""
|
||||||
|
with db.engine.begin() as conn:
|
||||||
|
rs = conn.execute(db.text(sql))
|
||||||
|
|
||||||
|
current_iter_count = 0
|
||||||
|
for i in rs:
|
||||||
|
record_id = str(i.id)
|
||||||
|
provider_name = str(i.provider_name)
|
||||||
|
retrieval_model = i.retrieval_model
|
||||||
|
print(type(retrieval_model))
|
||||||
|
|
||||||
|
if record_id in failed_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
retrieval_model_changed = False
|
||||||
|
if retrieval_model:
|
||||||
|
if (
|
||||||
|
"reranking_model" in retrieval_model
|
||||||
|
and "reranking_provider_name" in retrieval_model["reranking_model"]
|
||||||
|
and retrieval_model["reranking_model"]["reranking_provider_name"]
|
||||||
|
and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
|
||||||
|
):
|
||||||
|
click.echo(
|
||||||
|
click.style(
|
||||||
|
f"[{processed_count}] Migrating {table_name} {record_id} "
|
||||||
|
f"(reranking_provider_name: "
|
||||||
|
f"{retrieval_model['reranking_model']['reranking_provider_name']})",
|
||||||
|
fg="white",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
retrieval_model["reranking_model"]["reranking_provider_name"] = (
|
||||||
|
f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}"
|
||||||
|
)
|
||||||
|
retrieval_model_changed = True
|
||||||
|
|
||||||
|
click.echo(
|
||||||
|
click.style(
|
||||||
|
f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
|
||||||
|
fg="white",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# update provider name append with "langgenius/{provider_name}/{provider_name}"
|
||||||
|
params = {"record_id": record_id}
|
||||||
|
update_retrieval_model_sql = ""
|
||||||
|
if retrieval_model and retrieval_model_changed:
|
||||||
|
update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
|
||||||
|
params["retrieval_model"] = json.dumps(retrieval_model)
|
||||||
|
|
||||||
|
sql = f"""update {table_name}
|
||||||
|
set {provider_column_name} =
|
||||||
|
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
|
||||||
|
{update_retrieval_model_sql}
|
||||||
|
where id = :record_id"""
|
||||||
|
conn.execute(db.text(sql), params)
|
||||||
|
click.echo(
|
||||||
|
click.style(
|
||||||
|
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
|
||||||
|
fg="green",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
failed_ids.append(record_id)
|
||||||
|
click.echo(
|
||||||
|
click.style(
|
||||||
|
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
|
||||||
|
fg="red",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.exception(
|
||||||
|
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
current_iter_count += 1
|
||||||
|
processed_count += 1
|
||||||
|
|
||||||
|
if not current_iter_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
click.echo(
|
||||||
|
click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
|
||||||
|
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||||
|
|
||||||
|
processed_count = 0
|
||||||
|
failed_ids = []
|
||||||
|
while True:
|
||||||
|
sql = f"""select id, {provider_column_name} as provider_name from {table_name}
|
||||||
|
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
|
||||||
|
limit 1000"""
|
||||||
|
with db.engine.begin() as conn:
|
||||||
|
rs = conn.execute(db.text(sql))
|
||||||
|
|
||||||
|
current_iter_count = 0
|
||||||
|
for i in rs:
|
||||||
|
current_iter_count += 1
|
||||||
|
processed_count += 1
|
||||||
|
record_id = str(i.id)
|
||||||
|
provider_name = str(i.provider_name)
|
||||||
|
|
||||||
|
if record_id in failed_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
click.echo(
|
||||||
|
click.style(
|
||||||
|
f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
|
||||||
|
fg="white",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# update provider name append with "langgenius/{provider_name}/{provider_name}"
|
||||||
|
sql = f"""update {table_name}
|
||||||
|
set {provider_column_name} =
|
||||||
|
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
|
||||||
|
where id = :record_id"""
|
||||||
|
conn.execute(db.text(sql), {"record_id": record_id})
|
||||||
|
click.echo(
|
||||||
|
click.style(
|
||||||
|
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
|
||||||
|
fg="green",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
failed_ids.append(record_id)
|
||||||
|
click.echo(
|
||||||
|
click.style(
|
||||||
|
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
|
||||||
|
fg="red",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.exception(
|
||||||
|
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not current_iter_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
click.echo(
|
||||||
|
click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user