diff --git a/api/commands.py b/api/commands.py index 932102db97..fff77430a9 100644 --- a/api/commands.py +++ b/api/commands.py @@ -329,16 +329,23 @@ def create_qdrant_indexes(): model_name=dataset.embedding_model ) except Exception: - provider = Provider( - id='provider_id', - tenant_id=dataset.tenant_id, - provider_name='openai', - provider_type=ProviderType.CUSTOM.value, - encrypted_config=json.dumps({'openai_api_key': 'TEST'}), - is_valid=True, - ) - model_provider = OpenAIProvider(provider=provider) - embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider) + try: + embedding_model = ModelFactory.get_embedding_model( + tenant_id=dataset.tenant_id + ) + dataset.embedding_model = embedding_model.name + dataset.embedding_model_provider = embedding_model.model_provider.provider_name + except Exception: + provider = Provider( + id='provider_id', + tenant_id=dataset.tenant_id, + provider_name='openai', + provider_type=ProviderType.SYSTEM.value, + encrypted_config=json.dumps({'openai_api_key': 'TEST'}), + is_valid=True, + ) + model_provider = OpenAIProvider(provider=provider) + embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider) embeddings = CacheEmbedding(embedding_model) from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig