mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 19:59:50 +08:00
feat: Add Vanna.AI as a builtin tool (#4878)
Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
parent
7133a16511
commit
2d9f55b632
BIN
api/core/tools/provider/builtin/vanna/_assets/icon.png
Normal file
BIN
api/core/tools/provider/builtin/vanna/_assets/icon.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 4.5 KiB |
119
api/core/tools/provider/builtin/vanna/tools/vanna.py
Normal file
119
api/core/tools/provider/builtin/vanna/tools/vanna.py
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from vanna.remote import VannaDefault
|
||||||
|
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
|
|
||||||
|
class VannaTool(BuiltinTool):
|
||||||
|
def _invoke(
|
||||||
|
self, user_id: str, tool_parameters: dict[str, Any]
|
||||||
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
|
"""
|
||||||
|
invoke tools
|
||||||
|
"""
|
||||||
|
api_key = self.runtime.credentials.get("api_key", None)
|
||||||
|
if not api_key:
|
||||||
|
raise ToolProviderCredentialValidationError("Please input api key")
|
||||||
|
|
||||||
|
model = tool_parameters.get("model", "")
|
||||||
|
if not model:
|
||||||
|
return self.create_text_message("Please input RAG model")
|
||||||
|
|
||||||
|
prompt = tool_parameters.get("prompt", "")
|
||||||
|
if not prompt:
|
||||||
|
return self.create_text_message("Please input prompt")
|
||||||
|
|
||||||
|
url = tool_parameters.get("url", "")
|
||||||
|
if not url:
|
||||||
|
return self.create_text_message("Please input URL/Host/DSN")
|
||||||
|
|
||||||
|
db_name = tool_parameters.get("db_name", "")
|
||||||
|
username = tool_parameters.get("username", "")
|
||||||
|
password = tool_parameters.get("password", "")
|
||||||
|
port = tool_parameters.get("port", 0)
|
||||||
|
|
||||||
|
vn = VannaDefault(model=model, api_key=api_key)
|
||||||
|
|
||||||
|
db_type = tool_parameters.get("db_type", "")
|
||||||
|
if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]:
|
||||||
|
if not db_name:
|
||||||
|
return self.create_text_message("Please input database name")
|
||||||
|
if not username:
|
||||||
|
return self.create_text_message("Please input username")
|
||||||
|
if port < 1:
|
||||||
|
return self.create_text_message("Please input port")
|
||||||
|
|
||||||
|
schema_sql = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS"
|
||||||
|
match db_type:
|
||||||
|
case "SQLite":
|
||||||
|
schema_sql = "SELECT type, sql FROM sqlite_master WHERE sql is not null"
|
||||||
|
vn.connect_to_sqlite(url)
|
||||||
|
case "Postgres":
|
||||||
|
vn.connect_to_postgres(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||||
|
case "DuckDB":
|
||||||
|
vn.connect_to_duckdb(url=url)
|
||||||
|
case "SQLServer":
|
||||||
|
vn.connect_to_mssql(url)
|
||||||
|
case "MySQL":
|
||||||
|
vn.connect_to_mysql(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||||
|
case "Oracle":
|
||||||
|
vn.connect_to_oracle(user=username, password=password, dsn=url)
|
||||||
|
case "Hive":
|
||||||
|
vn.connect_to_hive(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||||
|
case "ClickHouse":
|
||||||
|
vn.connect_to_clickhouse(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||||
|
|
||||||
|
enable_training = tool_parameters.get("enable_training", False)
|
||||||
|
reset_training_data = tool_parameters.get("reset_training_data", False)
|
||||||
|
if enable_training:
|
||||||
|
if reset_training_data:
|
||||||
|
existing_training_data = vn.get_training_data()
|
||||||
|
if len(existing_training_data) > 0:
|
||||||
|
for _, training_data in existing_training_data.iterrows():
|
||||||
|
vn.remove_training_data(training_data["id"])
|
||||||
|
|
||||||
|
ddl = tool_parameters.get("ddl", "")
|
||||||
|
question = tool_parameters.get("question", "")
|
||||||
|
sql = tool_parameters.get("sql", "")
|
||||||
|
memos = tool_parameters.get("memos", "")
|
||||||
|
training_metadata = tool_parameters.get("training_metadata", False)
|
||||||
|
|
||||||
|
if training_metadata:
|
||||||
|
if db_type == "SQLite":
|
||||||
|
df_ddl = vn.run_sql(schema_sql)
|
||||||
|
for ddl in df_ddl["sql"].to_list():
|
||||||
|
vn.train(ddl=ddl)
|
||||||
|
else:
|
||||||
|
df_information_schema = vn.run_sql(schema_sql)
|
||||||
|
plan = vn.get_training_plan_generic(df_information_schema)
|
||||||
|
vn.train(plan=plan)
|
||||||
|
|
||||||
|
if ddl:
|
||||||
|
vn.train(ddl=ddl)
|
||||||
|
|
||||||
|
if sql:
|
||||||
|
if question:
|
||||||
|
vn.train(question=question, sql=sql)
|
||||||
|
else:
|
||||||
|
vn.train(sql=sql)
|
||||||
|
if memos:
|
||||||
|
vn.train(documentation=memos)
|
||||||
|
|
||||||
|
generate_chart = tool_parameters.get("generate_chart", True)
|
||||||
|
res = vn.ask(prompt, False, True, generate_chart)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
if res is not None:
|
||||||
|
result.append(self.create_text_message(res[0]))
|
||||||
|
if len(res) > 1 and res[1] is not None:
|
||||||
|
result.append(self.create_text_message(res[1].to_markdown()))
|
||||||
|
if len(res) > 2 and res[2] is not None:
|
||||||
|
result.append(
|
||||||
|
self.create_blob_message(blob=res[2].to_image(format="svg"), meta={"mime_type": "image/svg+xml"})
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
213
api/core/tools/provider/builtin/vanna/tools/vanna.yaml
Normal file
213
api/core/tools/provider/builtin/vanna/tools/vanna.yaml
Normal file
|
@ -0,0 +1,213 @@
|
||||||
|
identity:
|
||||||
|
name: vanna
|
||||||
|
author: QCTC
|
||||||
|
label:
|
||||||
|
en_US: Vanna.AI
|
||||||
|
zh_Hans: Vanna.AI
|
||||||
|
description:
|
||||||
|
human:
|
||||||
|
en_US: The fastest way to get actionable insights from your database just by asking questions.
|
||||||
|
zh_Hans: 一个基于大模型和RAG的Text2SQL工具。
|
||||||
|
llm: A tool for converting text to SQL.
|
||||||
|
parameters:
|
||||||
|
- name: prompt
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: Prompt
|
||||||
|
zh_Hans: 提示词
|
||||||
|
pt_BR: Prompt
|
||||||
|
human_description:
|
||||||
|
en_US: used for generating SQL
|
||||||
|
zh_Hans: 用于生成SQL
|
||||||
|
llm_description: key words for generating SQL
|
||||||
|
form: llm
|
||||||
|
- name: model
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: RAG Model
|
||||||
|
zh_Hans: RAG模型
|
||||||
|
human_description:
|
||||||
|
en_US: RAG Model for your database DDL
|
||||||
|
zh_Hans: 存储数据库训练数据的RAG模型
|
||||||
|
llm_description: RAG Model for generating SQL
|
||||||
|
form: form
|
||||||
|
- name: db_type
|
||||||
|
type: select
|
||||||
|
required: true
|
||||||
|
options:
|
||||||
|
- value: SQLite
|
||||||
|
label:
|
||||||
|
en_US: SQLite
|
||||||
|
zh_Hans: SQLite
|
||||||
|
- value: Postgres
|
||||||
|
label:
|
||||||
|
en_US: Postgres
|
||||||
|
zh_Hans: Postgres
|
||||||
|
- value: DuckDB
|
||||||
|
label:
|
||||||
|
en_US: DuckDB
|
||||||
|
zh_Hans: DuckDB
|
||||||
|
- value: SQLServer
|
||||||
|
label:
|
||||||
|
en_US: Microsoft SQL Server
|
||||||
|
zh_Hans: 微软 SQL Server
|
||||||
|
- value: MySQL
|
||||||
|
label:
|
||||||
|
en_US: MySQL
|
||||||
|
zh_Hans: MySQL
|
||||||
|
- value: Oracle
|
||||||
|
label:
|
||||||
|
en_US: Oracle
|
||||||
|
zh_Hans: Oracle
|
||||||
|
- value: Hive
|
||||||
|
label:
|
||||||
|
en_US: Hive
|
||||||
|
zh_Hans: Hive
|
||||||
|
- value: ClickHouse
|
||||||
|
label:
|
||||||
|
en_US: ClickHouse
|
||||||
|
zh_Hans: ClickHouse
|
||||||
|
default: SQLite
|
||||||
|
label:
|
||||||
|
en_US: DB Type
|
||||||
|
zh_Hans: 数据库类型
|
||||||
|
human_description:
|
||||||
|
en_US: Database type.
|
||||||
|
zh_Hans: 选择要链接的数据库类型。
|
||||||
|
form: form
|
||||||
|
- name: url
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: URL/Host/DSN
|
||||||
|
zh_Hans: URL/Host/DSN
|
||||||
|
human_description:
|
||||||
|
en_US: Please input depending on DB type, visit https://vanna.ai/docs/ for more specification
|
||||||
|
zh_Hans: 请根据数据库类型,填入对应值,详情参考https://vanna.ai/docs/
|
||||||
|
form: form
|
||||||
|
- name: db_name
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: DB name
|
||||||
|
zh_Hans: 数据库名
|
||||||
|
human_description:
|
||||||
|
en_US: Database name
|
||||||
|
zh_Hans: 数据库名
|
||||||
|
form: form
|
||||||
|
- name: username
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Username
|
||||||
|
zh_Hans: 用户名
|
||||||
|
human_description:
|
||||||
|
en_US: Username
|
||||||
|
zh_Hans: 用户名
|
||||||
|
form: form
|
||||||
|
- name: password
|
||||||
|
type: secret-input
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Password
|
||||||
|
zh_Hans: 密码
|
||||||
|
human_description:
|
||||||
|
en_US: Password
|
||||||
|
zh_Hans: 密码
|
||||||
|
form: form
|
||||||
|
- name: port
|
||||||
|
type: number
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Port
|
||||||
|
zh_Hans: 端口
|
||||||
|
human_description:
|
||||||
|
en_US: Port
|
||||||
|
zh_Hans: 端口
|
||||||
|
form: form
|
||||||
|
- name: ddl
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Training DDL
|
||||||
|
zh_Hans: 训练DDL
|
||||||
|
human_description:
|
||||||
|
en_US: DDL statements for training data
|
||||||
|
zh_Hans: 用于训练RAG Model的建表语句
|
||||||
|
form: form
|
||||||
|
- name: question
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Training Question
|
||||||
|
zh_Hans: 训练问题
|
||||||
|
human_description:
|
||||||
|
en_US: Question-SQL Pairs
|
||||||
|
zh_Hans: Question-SQL中的问题
|
||||||
|
form: form
|
||||||
|
- name: sql
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Training SQL
|
||||||
|
zh_Hans: 训练SQL
|
||||||
|
human_description:
|
||||||
|
en_US: SQL queries to your training data
|
||||||
|
zh_Hans: 用于训练RAG Model的SQL语句
|
||||||
|
form: form
|
||||||
|
- name: memos
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Training Memos
|
||||||
|
zh_Hans: 训练说明
|
||||||
|
human_description:
|
||||||
|
en_US: Sometimes you may want to add documentation about your business terminology or definitions
|
||||||
|
zh_Hans: 添加更多关于数据库的业务说明
|
||||||
|
form: form
|
||||||
|
- name: enable_training
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
default: false
|
||||||
|
label:
|
||||||
|
en_US: Training Data
|
||||||
|
zh_Hans: 训练数据
|
||||||
|
human_description:
|
||||||
|
en_US: You only need to train once. Do not train again unless you want to add more training data
|
||||||
|
zh_Hans: 训练数据无更新时,训练一次即可
|
||||||
|
form: form
|
||||||
|
- name: reset_training_data
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
default: false
|
||||||
|
label:
|
||||||
|
en_US: Reset Training Data
|
||||||
|
zh_Hans: 重置训练数据
|
||||||
|
human_description:
|
||||||
|
en_US: Remove all training data in the current RAG Model
|
||||||
|
zh_Hans: 删除当前RAG Model中的所有训练数据
|
||||||
|
form: form
|
||||||
|
- name: training_metadata
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
default: false
|
||||||
|
label:
|
||||||
|
en_US: Training Metadata
|
||||||
|
zh_Hans: 训练元数据
|
||||||
|
human_description:
|
||||||
|
en_US: If enabled, it will attempt to train on the metadata of that database
|
||||||
|
zh_Hans: 是否自动从数据库获取元数据来训练
|
||||||
|
form: form
|
||||||
|
- name: generate_chart
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
default: True
|
||||||
|
label:
|
||||||
|
en_US: Generate Charts
|
||||||
|
zh_Hans: 生成图表
|
||||||
|
human_description:
|
||||||
|
en_US: Generate Charts
|
||||||
|
zh_Hans: 是否生成图表
|
||||||
|
form: form
|
25
api/core/tools/provider/builtin/vanna/vanna.py
Normal file
25
api/core/tools/provider/builtin/vanna/vanna.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
|
from core.tools.provider.builtin.vanna.tools.vanna import VannaTool
|
||||||
|
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||||
|
|
||||||
|
|
||||||
|
class VannaProvider(BuiltinToolProviderController):
|
||||||
|
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||||
|
try:
|
||||||
|
VannaTool().fork_tool_runtime(
|
||||||
|
runtime={
|
||||||
|
"credentials": credentials,
|
||||||
|
}
|
||||||
|
).invoke(
|
||||||
|
user_id='',
|
||||||
|
tool_parameters={
|
||||||
|
"model": "chinook",
|
||||||
|
"db_type": "SQLite",
|
||||||
|
"url": "https://vanna.ai/Chinook.sqlite",
|
||||||
|
"query": "What are the top 10 customers by sales?"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolProviderCredentialValidationError(str(e))
|
25
api/core/tools/provider/builtin/vanna/vanna.yaml
Normal file
25
api/core/tools/provider/builtin/vanna/vanna.yaml
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
identity:
|
||||||
|
author: QCTC
|
||||||
|
name: vanna
|
||||||
|
label:
|
||||||
|
en_US: Vanna.AI
|
||||||
|
zh_Hans: Vanna.AI
|
||||||
|
description:
|
||||||
|
en_US: The fastest way to get actionable insights from your database just by asking questions.
|
||||||
|
zh_Hans: 一个基于大模型和RAG的Text2SQL工具。
|
||||||
|
icon: icon.png
|
||||||
|
credentials_for_provider:
|
||||||
|
api_key:
|
||||||
|
type: secret-input
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: API key
|
||||||
|
zh_Hans: API key
|
||||||
|
placeholder:
|
||||||
|
en_US: Please input your API key
|
||||||
|
zh_Hans: 请输入你的 API key
|
||||||
|
pt_BR: Please input your API key
|
||||||
|
help:
|
||||||
|
en_US: Get your API key from Vanna.AI
|
||||||
|
zh_Hans: 从 Vanna.AI 获取你的 API key
|
||||||
|
url: https://vanna.ai/account/profile
|
|
@ -82,3 +82,4 @@ firecrawl-py==0.0.5
|
||||||
oss2==2.18.5
|
oss2==2.18.5
|
||||||
pgvector==0.2.5
|
pgvector==0.2.5
|
||||||
google-cloud-aiplatform==1.49.0
|
google-cloud-aiplatform==1.49.0
|
||||||
|
vanna[postgres,mysql,clickhouse,duckdb]==0.5.5
|
||||||
|
|
Loading…
Reference in New Issue
Block a user