dify/api/core/chain/main_chain_builder.py
John Wang 3241e4015b
feat: upgrade langchain (#430)
Co-authored-by: jyong <718720800@qq.com>
2023-06-25 16:49:14 +08:00

111 lines
4.0 KiB
Python

from typing import Optional, List, cast
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
from core.conversation_message_task import ConversationMessageTask
from extensions.ext_database import db
from models.dataset import Dataset
class MainChainBuilder:
    """Assemble the LangChain pipeline for one conversation turn from an app's ``agent_mode`` config.

    The pipeline consists of optional pre-fixed chains (currently sensitive-word
    avoidance) followed by a ``MultiDatasetRouterChain`` when any datasets are
    configured. All sub-chains are wrapped in a single ``SequentialChain``.
    """

    @classmethod
    def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
                                rest_tokens: int,
                                conversation_message_task: ConversationMessageTask) -> Optional[Chain]:
        """Build the overall ``SequentialChain`` for the given agent configuration.

        :param tenant_id: tenant used to scope the dataset lookup.
        :param agent_mode: agent-mode config dict; see ``get_agent_chains``.
        :param memory: optional chat memory, passed to the ``SequentialChain``.
        :param rest_tokens: token budget forwarded to the dataset router chain.
        :param conversation_message_task: passed to the gather callback handler
            and to the dataset router chain.
        :return: a ``SequentialChain`` whose input variable is ``"input"`` and whose
            output variable is the last sub-chain's first output key, or ``None``
            when the config yields no sub-chains.
        """
        first_input_key = "input"
        final_output_key = "output"

        chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)

        # agent mode
        tool_chains, chains_output_key = cls.get_agent_chains(
            tenant_id=tenant_id,
            agent_mode=agent_mode,
            rest_tokens=rest_tokens,
            memory=memory,
            conversation_message_task=conversation_message_task
        )
        chains = list(tool_chains)

        if chains_output_key:
            final_output_key = chains_output_key

        if not chains:
            return None

        # every sub-chain reports to the gather handler in addition to its own callbacks
        for chain in chains:
            cls._attach_callback(chain, chain_callback_handler)

        # build main chain
        overall_chain = SequentialChain(
            chains=chains,
            input_variables=[first_input_key],
            output_variables=[final_output_key],
            memory=memory,  # only for use the memory prompt input key
        )

        return overall_chain

    @staticmethod
    def _attach_callback(chain: Chain, handler) -> None:
        """Append ``handler`` to ``chain.callbacks``, creating the list when it is unset."""
        if chain.callbacks is None:
            # Chain.callbacks is optional in LangChain; calling .append() on None would raise.
            chain.callbacks = [handler]
        else:
            chain.callbacks.append(handler)

    @classmethod
    def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
                         rest_tokens: int,
                         memory: Optional[BaseChatMemory],
                         conversation_message_task: ConversationMessageTask):
        """Translate the ``agent_mode`` tool list into an ordered list of sub-chains.

        Each entry of ``agent_mode['tools']`` is a single-key dict mapping a tool
        type to its config. Recognised types:

        - ``'sensitive-word-avoidance'``: built via
          ``ChainBuilder.to_sensitive_word_avoidance_chain``; placed first.
        - ``'dataset'``: the config's ``"id"`` is looked up as a ``Dataset`` of
          ``tenant_id``; all found datasets feed one ``MultiDatasetRouterChain``,
          appended last.

        Unknown tool types and empty tool entries are ignored.

        :param memory: currently unused here. NOTE(review): presumably kept for
            signature compatibility — confirm before removing.
        :return: ``(chains, output_key)`` where ``output_key`` is the first output
            key of the last chain, or ``None`` when no chains were built.
        """
        chains = []
        if agent_mode and agent_mode.get('enabled'):
            # `or []` also covers an explicit null "tools" value
            tools = agent_mode.get('tools') or []

            pre_fixed_chains = []
            datasets = []
            for tool in tools:
                if not tool:
                    # an empty entry carries no tool type; skip it rather than raise IndexError
                    continue

                tool_type, tool_config = next(iter(tool.items()))

                if tool_type == 'sensitive-word-avoidance':
                    chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
                    if chain:
                        pre_fixed_chains.append(chain)
                elif tool_type == "dataset":
                    # get dataset from dataset id, restricted to the current tenant
                    dataset = db.session.query(Dataset).filter(
                        Dataset.tenant_id == tenant_id,
                        Dataset.id == tool_config.get("id")
                    ).first()

                    if dataset:
                        datasets.append(dataset)

            # add pre-fixed chains
            chains += pre_fixed_chains

            if datasets:
                # all datasets are served by a single router chain
                multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
                    tenant_id=tenant_id,
                    datasets=datasets,
                    conversation_message_task=conversation_message_task,
                    rest_tokens=rest_tokens,
                    callbacks=[DifyStdOutCallbackHandler()]
                )
                chains.append(multi_dataset_router_chain)

        final_output_key = cls.get_chains_output_key(chains)

        return chains, final_output_key

    @classmethod
    def get_chains_output_key(cls, chains: List[Chain]) -> Optional[str]:
        """Return the first output key of the last chain, or ``None`` if ``chains`` is empty."""
        if chains:
            return chains[-1].output_keys[0]
        return None