mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
feat: Add Vanna.AI as a builtin tool (#4878)
Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
parent
7133a16511
commit
2d9f55b632
BIN
api/core/tools/provider/builtin/vanna/_assets/icon.png
Normal file
BIN
api/core/tools/provider/builtin/vanna/_assets/icon.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 4.5 KiB |
119
api/core/tools/provider/builtin/vanna/tools/vanna.py
Normal file
119
api/core/tools/provider/builtin/vanna/tools/vanna.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
from typing import Any, Union
|
||||
|
||||
from vanna.remote import VannaDefault
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class VannaTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
api_key = self.runtime.credentials.get("api_key", None)
|
||||
if not api_key:
|
||||
raise ToolProviderCredentialValidationError("Please input api key")
|
||||
|
||||
model = tool_parameters.get("model", "")
|
||||
if not model:
|
||||
return self.create_text_message("Please input RAG model")
|
||||
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message("Please input prompt")
|
||||
|
||||
url = tool_parameters.get("url", "")
|
||||
if not url:
|
||||
return self.create_text_message("Please input URL/Host/DSN")
|
||||
|
||||
db_name = tool_parameters.get("db_name", "")
|
||||
username = tool_parameters.get("username", "")
|
||||
password = tool_parameters.get("password", "")
|
||||
port = tool_parameters.get("port", 0)
|
||||
|
||||
vn = VannaDefault(model=model, api_key=api_key)
|
||||
|
||||
db_type = tool_parameters.get("db_type", "")
|
||||
if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]:
|
||||
if not db_name:
|
||||
return self.create_text_message("Please input database name")
|
||||
if not username:
|
||||
return self.create_text_message("Please input username")
|
||||
if port < 1:
|
||||
return self.create_text_message("Please input port")
|
||||
|
||||
schema_sql = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS"
|
||||
match db_type:
|
||||
case "SQLite":
|
||||
schema_sql = "SELECT type, sql FROM sqlite_master WHERE sql is not null"
|
||||
vn.connect_to_sqlite(url)
|
||||
case "Postgres":
|
||||
vn.connect_to_postgres(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||
case "DuckDB":
|
||||
vn.connect_to_duckdb(url=url)
|
||||
case "SQLServer":
|
||||
vn.connect_to_mssql(url)
|
||||
case "MySQL":
|
||||
vn.connect_to_mysql(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||
case "Oracle":
|
||||
vn.connect_to_oracle(user=username, password=password, dsn=url)
|
||||
case "Hive":
|
||||
vn.connect_to_hive(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||
case "ClickHouse":
|
||||
vn.connect_to_clickhouse(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||
|
||||
enable_training = tool_parameters.get("enable_training", False)
|
||||
reset_training_data = tool_parameters.get("reset_training_data", False)
|
||||
if enable_training:
|
||||
if reset_training_data:
|
||||
existing_training_data = vn.get_training_data()
|
||||
if len(existing_training_data) > 0:
|
||||
for _, training_data in existing_training_data.iterrows():
|
||||
vn.remove_training_data(training_data["id"])
|
||||
|
||||
ddl = tool_parameters.get("ddl", "")
|
||||
question = tool_parameters.get("question", "")
|
||||
sql = tool_parameters.get("sql", "")
|
||||
memos = tool_parameters.get("memos", "")
|
||||
training_metadata = tool_parameters.get("training_metadata", False)
|
||||
|
||||
if training_metadata:
|
||||
if db_type == "SQLite":
|
||||
df_ddl = vn.run_sql(schema_sql)
|
||||
for ddl in df_ddl["sql"].to_list():
|
||||
vn.train(ddl=ddl)
|
||||
else:
|
||||
df_information_schema = vn.run_sql(schema_sql)
|
||||
plan = vn.get_training_plan_generic(df_information_schema)
|
||||
vn.train(plan=plan)
|
||||
|
||||
if ddl:
|
||||
vn.train(ddl=ddl)
|
||||
|
||||
if sql:
|
||||
if question:
|
||||
vn.train(question=question, sql=sql)
|
||||
else:
|
||||
vn.train(sql=sql)
|
||||
if memos:
|
||||
vn.train(documentation=memos)
|
||||
|
||||
generate_chart = tool_parameters.get("generate_chart", True)
|
||||
res = vn.ask(prompt, False, True, generate_chart)
|
||||
|
||||
result = []
|
||||
|
||||
if res is not None:
|
||||
result.append(self.create_text_message(res[0]))
|
||||
if len(res) > 1 and res[1] is not None:
|
||||
result.append(self.create_text_message(res[1].to_markdown()))
|
||||
if len(res) > 2 and res[2] is not None:
|
||||
result.append(
|
||||
self.create_blob_message(blob=res[2].to_image(format="svg"), meta={"mime_type": "image/svg+xml"})
|
||||
)
|
||||
|
||||
return result
|
213
api/core/tools/provider/builtin/vanna/tools/vanna.yaml
Normal file
213
api/core/tools/provider/builtin/vanna/tools/vanna.yaml
Normal file
|
@ -0,0 +1,213 @@
|
|||
identity:
|
||||
name: vanna
|
||||
author: QCTC
|
||||
label:
|
||||
en_US: Vanna.AI
|
||||
zh_Hans: Vanna.AI
|
||||
description:
|
||||
human:
|
||||
en_US: The fastest way to get actionable insights from your database just by asking questions.
|
||||
zh_Hans: 一个基于大模型和RAG的Text2SQL工具。
|
||||
llm: A tool for converting text to SQL.
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
pt_BR: Prompt
|
||||
human_description:
|
||||
en_US: used for generating SQL
|
||||
zh_Hans: 用于生成SQL
|
||||
llm_description: key words for generating SQL
|
||||
form: llm
|
||||
- name: model
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: RAG Model
|
||||
zh_Hans: RAG模型
|
||||
human_description:
|
||||
en_US: RAG Model for your database DDL
|
||||
zh_Hans: 存储数据库训练数据的RAG模型
|
||||
llm_description: RAG Model for generating SQL
|
||||
form: form
|
||||
- name: db_type
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- value: SQLite
|
||||
label:
|
||||
en_US: SQLite
|
||||
zh_Hans: SQLite
|
||||
- value: Postgres
|
||||
label:
|
||||
en_US: Postgres
|
||||
zh_Hans: Postgres
|
||||
- value: DuckDB
|
||||
label:
|
||||
en_US: DuckDB
|
||||
zh_Hans: DuckDB
|
||||
- value: SQLServer
|
||||
label:
|
||||
en_US: Microsoft SQL Server
|
||||
zh_Hans: 微软 SQL Server
|
||||
- value: MySQL
|
||||
label:
|
||||
en_US: MySQL
|
||||
zh_Hans: MySQL
|
||||
- value: Oracle
|
||||
label:
|
||||
en_US: Oracle
|
||||
zh_Hans: Oracle
|
||||
- value: Hive
|
||||
label:
|
||||
en_US: Hive
|
||||
zh_Hans: Hive
|
||||
- value: ClickHouse
|
||||
label:
|
||||
en_US: ClickHouse
|
||||
zh_Hans: ClickHouse
|
||||
default: SQLite
|
||||
label:
|
||||
en_US: DB Type
|
||||
zh_Hans: 数据库类型
|
||||
human_description:
|
||||
en_US: Database type.
|
||||
zh_Hans: 选择要链接的数据库类型。
|
||||
form: form
|
||||
- name: url
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: URL/Host/DSN
|
||||
zh_Hans: URL/Host/DSN
|
||||
human_description:
|
||||
en_US: Please input depending on DB type, visit https://vanna.ai/docs/ for more specification
|
||||
zh_Hans: 请根据数据库类型,填入对应值,详情参考https://vanna.ai/docs/
|
||||
form: form
|
||||
- name: db_name
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: DB name
|
||||
zh_Hans: 数据库名
|
||||
human_description:
|
||||
en_US: Database name
|
||||
zh_Hans: 数据库名
|
||||
form: form
|
||||
- name: username
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Username
|
||||
zh_Hans: 用户名
|
||||
human_description:
|
||||
en_US: Username
|
||||
zh_Hans: 用户名
|
||||
form: form
|
||||
- name: password
|
||||
type: secret-input
|
||||
required: false
|
||||
label:
|
||||
en_US: Password
|
||||
zh_Hans: 密码
|
||||
human_description:
|
||||
en_US: Password
|
||||
zh_Hans: 密码
|
||||
form: form
|
||||
- name: port
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Port
|
||||
zh_Hans: 端口
|
||||
human_description:
|
||||
en_US: Port
|
||||
zh_Hans: 端口
|
||||
form: form
|
||||
- name: ddl
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Training DDL
|
||||
zh_Hans: 训练DDL
|
||||
human_description:
|
||||
en_US: DDL statements for training data
|
||||
zh_Hans: 用于训练RAG Model的建表语句
|
||||
form: form
|
||||
- name: question
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Training Question
|
||||
zh_Hans: 训练问题
|
||||
human_description:
|
||||
en_US: Question-SQL Pairs
|
||||
zh_Hans: Question-SQL中的问题
|
||||
form: form
|
||||
- name: sql
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Training SQL
|
||||
zh_Hans: 训练SQL
|
||||
human_description:
|
||||
en_US: SQL queries to your training data
|
||||
zh_Hans: 用于训练RAG Model的SQL语句
|
||||
form: form
|
||||
- name: memos
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Training Memos
|
||||
zh_Hans: 训练说明
|
||||
human_description:
|
||||
en_US: Sometimes you may want to add documentation about your business terminology or definitions
|
||||
zh_Hans: 添加更多关于数据库的业务说明
|
||||
form: form
|
||||
- name: enable_training
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
label:
|
||||
en_US: Training Data
|
||||
zh_Hans: 训练数据
|
||||
human_description:
|
||||
en_US: You only need to train once. Do not train again unless you want to add more training data
|
||||
zh_Hans: 训练数据无更新时,训练一次即可
|
||||
form: form
|
||||
- name: reset_training_data
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
label:
|
||||
en_US: Reset Training Data
|
||||
zh_Hans: 重置训练数据
|
||||
human_description:
|
||||
en_US: Remove all training data in the current RAG Model
|
||||
zh_Hans: 删除当前RAG Model中的所有训练数据
|
||||
form: form
|
||||
- name: training_metadata
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
label:
|
||||
en_US: Training Metadata
|
||||
zh_Hans: 训练元数据
|
||||
human_description:
|
||||
en_US: If enabled, it will attempt to train on the metadata of that database
|
||||
zh_Hans: 是否自动从数据库获取元数据来训练
|
||||
form: form
|
||||
- name: generate_chart
|
||||
type: boolean
|
||||
required: false
|
||||
default: True
|
||||
label:
|
||||
en_US: Generate Charts
|
||||
zh_Hans: 生成图表
|
||||
human_description:
|
||||
en_US: Generate Charts
|
||||
zh_Hans: 是否生成图表
|
||||
form: form
|
25
api/core/tools/provider/builtin/vanna/vanna.py
Normal file
25
api/core/tools/provider/builtin/vanna/vanna.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.vanna.tools.vanna import VannaTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class VannaProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
VannaTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
"model": "chinook",
|
||||
"db_type": "SQLite",
|
||||
"url": "https://vanna.ai/Chinook.sqlite",
|
||||
"query": "What are the top 10 customers by sales?"
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
25
api/core/tools/provider/builtin/vanna/vanna.yaml
Normal file
25
api/core/tools/provider/builtin/vanna/vanna.yaml
Normal file
|
@ -0,0 +1,25 @@
|
|||
identity:
|
||||
author: QCTC
|
||||
name: vanna
|
||||
label:
|
||||
en_US: Vanna.AI
|
||||
zh_Hans: Vanna.AI
|
||||
description:
|
||||
en_US: The fastest way to get actionable insights from your database just by asking questions.
|
||||
zh_Hans: 一个基于大模型和RAG的Text2SQL工具。
|
||||
icon: icon.png
|
||||
credentials_for_provider:
|
||||
api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: API key
|
||||
zh_Hans: API key
|
||||
placeholder:
|
||||
en_US: Please input your API key
|
||||
zh_Hans: 请输入你的 API key
|
||||
pt_BR: Please input your API key
|
||||
help:
|
||||
en_US: Get your API key from Vanna.AI
|
||||
zh_Hans: 从 Vanna.AI 获取你的 API key
|
||||
url: https://vanna.ai/account/profile
|
|
@ -82,3 +82,4 @@ firecrawl-py==0.0.5
|
|||
oss2==2.18.5
|
||||
pgvector==0.2.5
|
||||
google-cloud-aiplatform==1.49.0
|
||||
vanna[postgres,mysql,clickhouse,duckdb]==0.5.5
|
||||
|
|
Loading…
Reference in New Issue
Block a user