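"""Agent executor built on top of LangChain.

Selects an agent implementation based on the configured planning strategy
(router, ReAct router, ReAct, or function calling), applies a moderation
check to the incoming query, and delegates execution to LangChain's
AgentExecutor.
"""
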
import enum
import logging
from typing import Optional, Union

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from langchain.agents import AgentExecutor as LCAgentExecutor
from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent
from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra


class PlanningStrategy(str, enum.Enum):
    ROUTER = 'router'
    REACT_ROUTER = 'react_router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'


class AgentConfiguration(BaseModel):
    strategy: PlanningStrategy
    model_config: ModelConfigEntity
    tools: list[BaseTool]
    summary_model_config: Optional[ModelConfigEntity] = None
    memory: Optional[TokenBufferMemory] = None
    callbacks: Callbacks = None
    max_iterations: int = 6
    max_execution_time: Optional[float] = None
    # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
    early_stopping_method: str = "generate"
    agent_llm_callback: Optional[AgentLLMCallback] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True


class AgentExecuteResult(BaseModel):
    strategy: PlanningStrategy
    output: Optional[str]
    configuration: AgentConfiguration


class AgentExecutor:
    def __init__(self, configuration: AgentConfiguration):
        self.configuration = configuration
        self.agent = self._init_agent()

    def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
        """Instantiate the agent implementation that matches the configured planning strategy."""
        if self.configuration.strategy == PlanningStrategy.REACT:
            agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
                model_config=self.configuration.model_config,
                tools=self.configuration.tools,
                output_parser=StructuredChatOutputParser(),
                summary_model_config=self.configuration.summary_model_config
                if self.configuration.summary_model_config else None,
                agent_llm_callback=self.configuration.agent_llm_callback,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
                model_config=self.configuration.model_config,
                tools=self.configuration.tools,
                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
                if self.configuration.memory else None,  # used to read chat history from memory
                summary_model_config=self.configuration.summary_model_config
                if self.configuration.summary_model_config else None,
                agent_llm_callback=self.configuration.agent_llm_callback,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.ROUTER:
            # the router strategy only works with dataset retrieval tools
            self.configuration.tools = [t for t in self.configuration.tools
                                        if isinstance(t, DatasetRetrieverTool)
                                        or isinstance(t, DatasetMultiRetrieverTool)]
            agent = MultiDatasetRouterAgent.from_llm_and_tools(
                model_config=self.configuration.model_config,
                tools=self.configuration.tools,
                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
                if self.configuration.memory else None,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
            # the ReAct router strategy likewise only works with dataset retrieval tools
            self.configuration.tools = [t for t in self.configuration.tools
                                        if isinstance(t, DatasetRetrieverTool)
                                        or isinstance(t, DatasetMultiRetrieverTool)]
            agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
                model_config=self.configuration.model_config,
                tools=self.configuration.tools,
                output_parser=StructuredChatOutputParser(),
                verbose=True
            )
        else:
            raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")

        return agent

    def should_use_agent(self, query: str) -> bool:
        return self.agent.should_use_agent(query)

    def run(self, query: str) -> AgentExecuteResult:
        # refuse to answer if the query is flagged by moderation
        moderation_result = moderation.check_moderation(
            self.configuration.model_config,
            query
        )

        if moderation_result:
            return AgentExecuteResult(
                output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
                strategy=self.configuration.strategy,
                configuration=self.configuration
            )

        agent_executor = LCAgentExecutor.from_agent_and_tools(
            agent=self.agent,
            tools=self.configuration.tools,
            max_iterations=self.configuration.max_iterations,
            max_execution_time=self.configuration.max_execution_time,
            early_stopping_method=self.configuration.early_stopping_method,
            callbacks=self.configuration.callbacks
        )

        try:
            output = agent_executor.run(input=query)
        except InvokeError:
            # model runtime invocation errors should bubble up to the caller
            raise
        except Exception:
            logging.exception("agent_executor run failed")
            output = None

        return AgentExecuteResult(
            output=output,
            strategy=self.configuration.strategy,
            configuration=self.configuration
        )
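

# Illustrative usage sketch (not part of the module; `model_config_entity` and
# `dataset_tool` are hypothetical objects assumed to be built by the caller):
#
#     configuration = AgentConfiguration(
#         strategy=PlanningStrategy.FUNCTION_CALL,
#         model_config=model_config_entity,
#         tools=[dataset_tool],
#         max_iterations=6,
#     )
#     executor = AgentExecutor(configuration)
#     if executor.should_use_agent(query):
#         result = executor.run(query)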