mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 19:59:50 +08:00
44 lines
1.6 KiB
Python
44 lines
1.6 KiB
Python
from typing import Dict
|
|
|
|
from langchain.tools import BaseTool
|
|
from llama_index.indices.base import BaseGPTIndex
|
|
from llama_index.langchain_helpers.agents import IndexToolConfig
|
|
from pydantic import Field
|
|
|
|
from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler
|
|
|
|
|
|
class EnhanceLlamaIndexTool(BaseTool):
    """Tool for querying a LlamaIndex.

    Each query response is passed to ``callback_handler.on_tool_end`` before
    being returned to the caller as a string.
    """

    # NOTE: name/description still needs to be set

    # Index queried by ``_run`` (via ``query``) and ``_arun`` (via ``aquery``).
    index: BaseGPTIndex
    # Extra keyword arguments forwarded to every ``query`` / ``aquery`` call.
    query_kwargs: Dict = Field(default_factory=dict)
    # NOTE(review): stored by ``from_tool_config`` but not read by ``_run`` or
    # ``_arun`` in this class -- confirm whether source info should be returned.
    return_sources: bool = False
    # Receives the raw response object of every query via ``on_tool_end``.
    callback_handler: IndexToolCallbackHandler

    @classmethod
    def from_tool_config(cls, tool_config: IndexToolConfig,
                         callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
        """Create a tool from a tool config.

        Args:
            tool_config: Supplies the index, name, description, index query
                kwargs, and any extra keyword arguments for the tool. An optional
                ``"return_sources"`` entry in ``tool_kwargs`` is mapped to the
                ``return_sources`` field.
            callback_handler: Handler notified with each query response.

        Returns:
            A new ``EnhanceLlamaIndexTool`` built from ``tool_config``.
        """
        # Work on a shallow copy so the caller's ``tool_config.tool_kwargs`` is
        # left untouched; popping from it directly would drop "return_sources"
        # for any later tool built from the same config.
        tool_kwargs = dict(tool_config.tool_kwargs)
        return_sources = tool_kwargs.pop("return_sources", False)
        return cls(
            index=tool_config.index,
            callback_handler=callback_handler,
            name=tool_config.name,
            description=tool_config.description,
            return_sources=return_sources,
            query_kwargs=tool_config.index_query_kwargs,
            **tool_kwargs,
        )

    def _run(self, tool_input: str) -> str:
        """Query the index synchronously and return the response as text."""
        response = self.index.query(tool_input, **self.query_kwargs)
        # Hand the handler the raw response object, before it is stringified.
        self.callback_handler.on_tool_end(response)
        return str(response)

    async def _arun(self, tool_input: str) -> str:
        """Query the index asynchronously and return the response as text."""
        response = await self.index.aquery(tool_input, **self.query_kwargs)
        # Hand the handler the raw response object, before it is stringified.
        self.callback_handler.on_tool_end(response)
        return str(response)
|