fix: retrieval weights default value

This commit is contained in:
StyleZhang 2024-10-21 17:54:14 +08:00
parent f09b0c7382
commit c96dc1e70d
5 changed files with 26 additions and 8 deletions

View File

@ -16,6 +16,8 @@ import type { DataSet } from '@/models/datasets'
import {
getMultipleRetrievalConfig,
} from '@/app/components/workflow/nodes/knowledge-retrieval/utils'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
const Icon = (
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
@ -41,10 +43,14 @@ const DatasetConfig: FC = () => {
const hasData = dataSet.length > 0
const {
isValid: isValidRerankModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const onRemove = (id: string) => {
const filteredDataSets = dataSet.filter(item => item.id !== id)
setDataSet(filteredDataSets)
const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet)
const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, isValidRerankModel)
setDatasetConfigs({
...(datasetConfigs as any),
...retrievalConfig,

View File

@ -38,7 +38,7 @@ import ConfigContext from '@/context/debug-configuration'
import Config from '@/app/components/app/configuration/config'
import Debug from '@/app/components/app/configuration/debug'
import Confirm from '@/app/components/base/confirm'
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { ToastContext } from '@/app/components/base/toast'
import { fetchAppDetail, updateAppModelConfig } from '@/service/apps'
import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config'
@ -53,7 +53,10 @@ import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import Drawer from '@/app/components/base/drawer'
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
import {
useModelListAndDefaultModelAndCurrentProviderAndModel,
useTextGenerationCurrentProviderAndModelAndModelList,
} from '@/app/components/header/account-setting/model-provider-page/hooks'
import { fetchCollectionList } from '@/service/tools'
import { type Collection } from '@/app/components/tools/types'
import { useStore as useAppStore } from '@/app/components/app/store'
@ -217,6 +220,9 @@ const Configuration: FC = () => {
const [isShowSelectDataSet, { setTrue: showSelectDataSet, setFalse: hideSelectDataSet }] = useBoolean(false)
const selectedIds = dataSets.map(item => item.id)
const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false)
const {
isValid: isValidRerankModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const handleSelect = (data: DataSet[]) => {
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
hideSelectDataSet()
@ -263,7 +269,7 @@ const Configuration: FC = () => {
reranking_mode: restConfigs.reranking_mode,
weights: restConfigs.weights,
reranking_enable: restConfigs.reranking_enable,
}, newDatasets, dataSets)
}, newDatasets, dataSets, isValidRerankModel)
setDatasetConfigs({
...retrievalConfig,

View File

@ -180,6 +180,7 @@ export const useModelListAndDefaultModelAndCurrentProviderAndModel = (type: Mode
defaultModel,
currentProvider,
currentModel,
isValid: !!modelList.find(provider => provider.provider === currentProvider?.provider && provider.models.find(model => model.model === currentModel?.model)),
}
}

View File

@ -63,6 +63,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
const {
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
isValid: isValidRerankModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const {
@ -231,7 +232,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
const multipleRetrievalConfig = draft.multiple_retrieval_config
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets)
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, isValidRerankModel)
}
})
setInputs(newInputs)
@ -243,7 +244,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|| (allExternal && newDatasets.length > 1)
)
setRerankModelOpen(true)
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets])
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, isValidRerankModel])
const filterVar = useCallback((varPayload: Var) => {
return varPayload.type === VarType.string

View File

@ -92,6 +92,7 @@ export const getMultipleRetrievalConfig = (
multipleRetrievalConfig: MultipleRetrievalConfig,
selectedDatasets: DataSet[],
originalDatasets: DataSet[],
isValidRerankModel?: boolean,
) => {
const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0
@ -131,7 +132,7 @@ export const getMultipleRetrievalConfig = (
if (allHighQuality && !inconsistentEmbeddingModel && reranking_mode === undefined && allInternal)
result.reranking_mode = RerankingModeEnum.WeightedScore
if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) {
if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && !weights) {
result.weights = {
vector_setting: {
vector_weight: allHighQualityVectorSearch
@ -152,7 +153,10 @@ export const getMultipleRetrievalConfig = (
}
}
if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && weights) {
if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) {
if (!isValidRerankModel)
result.reranking_mode = RerankingModeEnum.WeightedScore
result.weights = {
vector_setting: {
vector_weight: allHighQualityVectorSearch