mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
fix: retrieval weights default value
This commit is contained in:
parent
f09b0c7382
commit
c96dc1e70d
|
@ -16,6 +16,8 @@ import type { DataSet } from '@/models/datasets'
|
|||
import {
|
||||
getMultipleRetrievalConfig,
|
||||
} from '@/app/components/workflow/nodes/knowledge-retrieval/utils'
|
||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
|
||||
const Icon = (
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
|
@ -41,10 +43,14 @@ const DatasetConfig: FC = () => {
|
|||
|
||||
const hasData = dataSet.length > 0
|
||||
|
||||
const {
|
||||
isValid: isValidRerankModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const onRemove = (id: string) => {
|
||||
const filteredDataSets = dataSet.filter(item => item.id !== id)
|
||||
setDataSet(filteredDataSets)
|
||||
const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet)
|
||||
const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, isValidRerankModel)
|
||||
setDatasetConfigs({
|
||||
...(datasetConfigs as any),
|
||||
...retrievalConfig,
|
||||
|
|
|
@ -38,7 +38,7 @@ import ConfigContext from '@/context/debug-configuration'
|
|||
import Config from '@/app/components/app/configuration/config'
|
||||
import Debug from '@/app/components/app/configuration/debug'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { fetchAppDetail, updateAppModelConfig } from '@/service/apps'
|
||||
import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config'
|
||||
|
@ -53,7 +53,10 @@ import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
|||
import Drawer from '@/app/components/base/drawer'
|
||||
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
|
||||
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import {
|
||||
useModelListAndDefaultModelAndCurrentProviderAndModel,
|
||||
useTextGenerationCurrentProviderAndModelAndModelList,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { fetchCollectionList } from '@/service/tools'
|
||||
import { type Collection } from '@/app/components/tools/types'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
|
@ -217,6 +220,9 @@ const Configuration: FC = () => {
|
|||
const [isShowSelectDataSet, { setTrue: showSelectDataSet, setFalse: hideSelectDataSet }] = useBoolean(false)
|
||||
const selectedIds = dataSets.map(item => item.id)
|
||||
const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false)
|
||||
const {
|
||||
isValid: isValidRerankModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
const handleSelect = (data: DataSet[]) => {
|
||||
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
|
||||
hideSelectDataSet()
|
||||
|
@ -263,7 +269,7 @@ const Configuration: FC = () => {
|
|||
reranking_mode: restConfigs.reranking_mode,
|
||||
weights: restConfigs.weights,
|
||||
reranking_enable: restConfigs.reranking_enable,
|
||||
}, newDatasets, dataSets)
|
||||
}, newDatasets, dataSets, isValidRerankModel)
|
||||
|
||||
setDatasetConfigs({
|
||||
...retrievalConfig,
|
||||
|
|
|
@ -180,6 +180,7 @@ export const useModelListAndDefaultModelAndCurrentProviderAndModel = (type: Mode
|
|||
defaultModel,
|
||||
currentProvider,
|
||||
currentModel,
|
||||
isValid: !!modelList.find(provider => provider.provider === currentProvider?.provider && provider.models.find(model => model.model === currentModel?.model)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -63,6 +63,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
const {
|
||||
modelList: rerankModelList,
|
||||
defaultModel: rerankDefaultModel,
|
||||
isValid: isValidRerankModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const {
|
||||
|
@ -231,7 +232,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
|
||||
if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets)
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, isValidRerankModel)
|
||||
}
|
||||
})
|
||||
setInputs(newInputs)
|
||||
|
@ -243,7 +244,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
|| (allExternal && newDatasets.length > 1)
|
||||
)
|
||||
setRerankModelOpen(true)
|
||||
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets])
|
||||
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, isValidRerankModel])
|
||||
|
||||
const filterVar = useCallback((varPayload: Var) => {
|
||||
return varPayload.type === VarType.string
|
||||
|
|
|
@ -92,6 +92,7 @@ export const getMultipleRetrievalConfig = (
|
|||
multipleRetrievalConfig: MultipleRetrievalConfig,
|
||||
selectedDatasets: DataSet[],
|
||||
originalDatasets: DataSet[],
|
||||
isValidRerankModel?: boolean,
|
||||
) => {
|
||||
const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0
|
||||
|
||||
|
@ -131,7 +132,7 @@ export const getMultipleRetrievalConfig = (
|
|||
if (allHighQuality && !inconsistentEmbeddingModel && reranking_mode === undefined && allInternal)
|
||||
result.reranking_mode = RerankingModeEnum.WeightedScore
|
||||
|
||||
if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) {
|
||||
if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && !weights) {
|
||||
result.weights = {
|
||||
vector_setting: {
|
||||
vector_weight: allHighQualityVectorSearch
|
||||
|
@ -152,7 +153,10 @@ export const getMultipleRetrievalConfig = (
|
|||
}
|
||||
}
|
||||
|
||||
if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && weights) {
|
||||
if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) {
|
||||
if (!isValidRerankModel)
|
||||
result.reranking_mode = RerankingModeEnum.WeightedScore
|
||||
|
||||
result.weights = {
|
||||
vector_setting: {
|
||||
vector_weight: allHighQualityVectorSearch
|
||||
|
|
Loading…
Reference in New Issue
Block a user