mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
feat: 异步风格插件方法注册器
This commit is contained in:
parent
fa823de6b0
commit
52a7c25540
|
@ -10,7 +10,6 @@ required_deps = {
|
||||||
"botpy": "qq-botpy",
|
"botpy": "qq-botpy",
|
||||||
"PIL": "pillow",
|
"PIL": "pillow",
|
||||||
"nakuru": "nakuru-project-idk",
|
"nakuru": "nakuru-project-idk",
|
||||||
"CallingGPT": "CallingGPT",
|
|
||||||
"tiktoken": "tiktoken",
|
"tiktoken": "tiktoken",
|
||||||
"yaml": "pyyaml",
|
"yaml": "pyyaml",
|
||||||
"aiohttp": "aiohttp",
|
"aiohttp": "aiohttp",
|
||||||
|
|
|
@ -13,6 +13,17 @@ class BasePlugin(metaclass=abc.ABCMeta):
|
||||||
"""插件基类"""
|
"""插件基类"""
|
||||||
|
|
||||||
host: APIHost
|
host: APIHost
|
||||||
|
"""API宿主"""
|
||||||
|
|
||||||
|
ap: app.Application
|
||||||
|
"""应用程序对象"""
|
||||||
|
|
||||||
|
def __init__(self, host: APIHost):
|
||||||
|
self.host = host
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""初始化插件"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class APIHost:
|
class APIHost:
|
||||||
|
@ -61,8 +72,10 @@ class EventContext:
|
||||||
"""事件编号"""
|
"""事件编号"""
|
||||||
|
|
||||||
host: APIHost = None
|
host: APIHost = None
|
||||||
|
"""API宿主"""
|
||||||
|
|
||||||
event: events.BaseEventModel = None
|
event: events.BaseEventModel = None
|
||||||
|
"""此次事件的对象,具体类型为handler注册时指定监听的类型,可查看events.py中的定义"""
|
||||||
|
|
||||||
__prevent_default__ = False
|
__prevent_default__ = False
|
||||||
"""是否阻止默认行为"""
|
"""是否阻止默认行为"""
|
||||||
|
|
|
@ -10,8 +10,10 @@ from ..provider import entities as llm_entities
|
||||||
|
|
||||||
|
|
||||||
class BaseEventModel(pydantic.BaseModel):
|
class BaseEventModel(pydantic.BaseModel):
|
||||||
|
"""事件模型基类"""
|
||||||
|
|
||||||
query: typing.Union[core_entities.Query, None]
|
query: typing.Union[core_entities.Query, None]
|
||||||
|
"""此次请求的query对象,可能为None"""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
|
@ -5,11 +5,10 @@ import pkgutil
|
||||||
import importlib
|
import importlib
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from CallingGPT.entities.namespace import get_func_schema
|
|
||||||
|
|
||||||
from .. import loader, events, context, models, host
|
from .. import loader, events, context, models, host
|
||||||
from ...core import entities as core_entities
|
from ...core import entities as core_entities
|
||||||
from ...provider.tools import entities as tools_entities
|
from ...provider.tools import entities as tools_entities
|
||||||
|
from ...utils import funcschema
|
||||||
|
|
||||||
|
|
||||||
class PluginLoader(loader.PluginLoader):
|
class PluginLoader(loader.PluginLoader):
|
||||||
|
@ -29,6 +28,9 @@ class PluginLoader(loader.PluginLoader):
|
||||||
setattr(models, 'on', self.on)
|
setattr(models, 'on', self.on)
|
||||||
setattr(models, 'func', self.func)
|
setattr(models, 'func', self.func)
|
||||||
|
|
||||||
|
setattr(models, 'handler', self.handler)
|
||||||
|
setattr(models, 'llm_func', self.llm_func)
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
|
@ -57,6 +59,8 @@ class PluginLoader(loader.PluginLoader):
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
# 过时
|
||||||
|
# 最早将于 v3.4 版本移除
|
||||||
def on(
|
def on(
|
||||||
self,
|
self,
|
||||||
event: typing.Type[events.BaseEventModel]
|
event: typing.Type[events.BaseEventModel]
|
||||||
|
@ -83,6 +87,8 @@ class PluginLoader(loader.PluginLoader):
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
# 过时
|
||||||
|
# 最早将于 v3.4 版本移除
|
||||||
def func(
|
def func(
|
||||||
self,
|
self,
|
||||||
name: str=None,
|
name: str=None,
|
||||||
|
@ -91,10 +97,11 @@ class PluginLoader(loader.PluginLoader):
|
||||||
self.ap.logger.debug(f'注册内容函数 {name}')
|
self.ap.logger.debug(f'注册内容函数 {name}')
|
||||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||||
|
|
||||||
function_schema = get_func_schema(func)
|
function_schema = funcschema.get_func_schema(func)
|
||||||
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
|
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
|
||||||
|
|
||||||
async def handler(
|
async def handler(
|
||||||
|
plugin: context.BasePlugin,
|
||||||
query: core_entities.Query,
|
query: core_entities.Query,
|
||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs
|
||||||
|
@ -116,6 +123,46 @@ class PluginLoader(loader.PluginLoader):
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
def handler(
|
||||||
|
self,
|
||||||
|
event: typing.Type[events.BaseEventModel]
|
||||||
|
) -> typing.Callable[[typing.Callable], typing.Callable]:
|
||||||
|
"""注册事件处理器"""
|
||||||
|
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
|
||||||
|
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||||
|
|
||||||
|
self._current_container.event_handlers[event] = func
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
def llm_func(
|
||||||
|
self,
|
||||||
|
name: str=None,
|
||||||
|
) -> typing.Callable:
|
||||||
|
"""注册内容函数"""
|
||||||
|
self.ap.logger.debug(f'注册内容函数 {name}')
|
||||||
|
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||||
|
|
||||||
|
function_schema = funcschema.get_func_schema(func)
|
||||||
|
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
|
||||||
|
|
||||||
|
llm_function = tools_entities.LLMFunction(
|
||||||
|
name=function_name,
|
||||||
|
human_desc='',
|
||||||
|
description=function_schema['description'],
|
||||||
|
enable=True,
|
||||||
|
parameters=function_schema['parameters'],
|
||||||
|
func=func,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._current_container.content_functions.append(llm_function)
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
async def _walk_plugin_path(
|
async def _walk_plugin_path(
|
||||||
self,
|
self,
|
||||||
module,
|
module,
|
|
@ -5,7 +5,7 @@ import traceback
|
||||||
|
|
||||||
from ..core import app
|
from ..core import app
|
||||||
from . import context, loader, events, installer, setting, models
|
from . import context, loader, events, installer, setting, models
|
||||||
from .loaders import legacy
|
from .loaders import classic
|
||||||
from .installers import github
|
from .installers import github
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ class PluginManager:
|
||||||
|
|
||||||
def __init__(self, ap: app.Application):
|
def __init__(self, ap: app.Application):
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
self.loader = legacy.PluginLoader(ap)
|
self.loader = classic.PluginLoader(ap)
|
||||||
self.installer = github.GitHubRepoInstaller(ap)
|
self.installer = github.GitHubRepoInstaller(ap)
|
||||||
self.setting = setting.SettingManager(ap)
|
self.setting = setting.SettingManager(ap)
|
||||||
self.api_host = context.APIHost(ap)
|
self.api_host = context.APIHost(ap)
|
||||||
|
@ -52,6 +52,9 @@ class PluginManager:
|
||||||
for plugin in self.plugins:
|
for plugin in self.plugins:
|
||||||
try:
|
try:
|
||||||
plugin.plugin_inst = plugin.plugin_class(self.api_host)
|
plugin.plugin_inst = plugin.plugin_class(self.api_host)
|
||||||
|
plugin.plugin_inst.ap = self.ap
|
||||||
|
plugin.plugin_inst.host = self.api_host
|
||||||
|
await plugin.plugin_inst.initialize()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
|
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
|
||||||
self.ap.logger.exception(e)
|
self.ap.logger.exception(e)
|
||||||
|
|
|
@ -24,3 +24,15 @@ def func(
|
||||||
name: str=None,
|
name: str=None,
|
||||||
) -> typing.Callable:
|
) -> typing.Callable:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def handler(
|
||||||
|
event: typing.Type[BaseEventModel]
|
||||||
|
) -> typing.Callable[[typing.Callable], typing.Callable]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def llm_func(
|
||||||
|
name: str=None,
|
||||||
|
) -> typing.Callable:
|
||||||
|
pass
|
|
@ -5,6 +5,7 @@ import traceback
|
||||||
|
|
||||||
from ...core import app, entities as core_entities
|
from ...core import app, entities as core_entities
|
||||||
from . import entities
|
from . import entities
|
||||||
|
from ...plugin import context as plugin_context
|
||||||
|
|
||||||
|
|
||||||
class ToolManager:
|
class ToolManager:
|
||||||
|
@ -28,6 +29,15 @@ class ToolManager:
|
||||||
return function
|
return function
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_function_and_plugin(self, name: str) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]:
|
||||||
|
"""获取函数和插件
|
||||||
|
"""
|
||||||
|
for plugin in self.ap.plugin_mgr.plugins:
|
||||||
|
for function in plugin.content_functions:
|
||||||
|
if function.name == name:
|
||||||
|
return function, plugin
|
||||||
|
return None, None
|
||||||
|
|
||||||
async def get_all_functions(self) -> list[entities.LLMFunction]:
|
async def get_all_functions(self) -> list[entities.LLMFunction]:
|
||||||
"""获取所有函数
|
"""获取所有函数
|
||||||
"""
|
"""
|
||||||
|
@ -68,7 +78,7 @@ class ToolManager:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
function = await self.get_function(name)
|
function, plugin = await self.get_function_and_plugin(name)
|
||||||
if function is None:
|
if function is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -79,7 +89,7 @@ class ToolManager:
|
||||||
**parameters
|
**parameters
|
||||||
}
|
}
|
||||||
|
|
||||||
return await function.func(**parameters)
|
return await function.func(plugin, **parameters)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
|
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
116
pkg/utils/funcschema.py
Normal file
116
pkg/utils/funcschema.py
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
import sys
|
||||||
|
import re
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
|
def get_func_schema(function: callable) -> dict:
|
||||||
|
"""
|
||||||
|
Return the data schema of a function.
|
||||||
|
{
|
||||||
|
"function": function,
|
||||||
|
"description": "function description",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"parameter_a": {
|
||||||
|
"type": "str",
|
||||||
|
"description": "parameter_a description"
|
||||||
|
},
|
||||||
|
"parameter_b": {
|
||||||
|
"type": "int",
|
||||||
|
"description": "parameter_b description"
|
||||||
|
},
|
||||||
|
"parameter_c": {
|
||||||
|
"type": "str",
|
||||||
|
"description": "parameter_c description",
|
||||||
|
"enum": ["a", "b", "c"]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["parameter_a", "parameter_b"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
func_doc = function.__doc__
|
||||||
|
# Google Style Docstring
|
||||||
|
if func_doc is None:
|
||||||
|
raise Exception("Function {} has no docstring.".format(function.__name__))
|
||||||
|
func_doc = func_doc.strip().replace(' ','').replace('\t', '')
|
||||||
|
# extract doc of args from docstring
|
||||||
|
doc_spt = func_doc.split('\n\n')
|
||||||
|
desc = doc_spt[0]
|
||||||
|
args = doc_spt[1] if len(doc_spt) > 1 else ""
|
||||||
|
returns = doc_spt[2] if len(doc_spt) > 2 else ""
|
||||||
|
|
||||||
|
# extract args
|
||||||
|
# delete the first line of args
|
||||||
|
arg_lines = args.split('\n')[1:]
|
||||||
|
arg_doc_list = re.findall(r'(\w+)(\((\w+)\))?:\s*(.*)', args)
|
||||||
|
args_doc = {}
|
||||||
|
for arg_line in arg_lines:
|
||||||
|
doc_tuple = re.findall(r'(\w+)(\(([\w\[\]]+)\))?:\s*(.*)', arg_line)
|
||||||
|
if len(doc_tuple) == 0:
|
||||||
|
continue
|
||||||
|
args_doc[doc_tuple[0][0]] = doc_tuple[0][3]
|
||||||
|
|
||||||
|
# extract returns
|
||||||
|
return_doc_list = re.findall(r'(\w+):\s*(.*)', returns)
|
||||||
|
|
||||||
|
params = enumerate(inspect.signature(function).parameters.values())
|
||||||
|
parameters = {
|
||||||
|
"type": "object",
|
||||||
|
"required": [],
|
||||||
|
"properties": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
for i, param in params:
|
||||||
|
|
||||||
|
# 排除 self, query
|
||||||
|
if param.name in ['self', 'query']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
param_type = param.annotation.__name__
|
||||||
|
|
||||||
|
type_name_mapping = {
|
||||||
|
"str": "string",
|
||||||
|
"int": "integer",
|
||||||
|
"float": "number",
|
||||||
|
"bool": "boolean",
|
||||||
|
"list": "array",
|
||||||
|
"dict": "object",
|
||||||
|
}
|
||||||
|
|
||||||
|
if param_type in type_name_mapping:
|
||||||
|
param_type = type_name_mapping[param_type]
|
||||||
|
|
||||||
|
parameters['properties'][param.name] = {
|
||||||
|
"type": param_type,
|
||||||
|
"description": args_doc[param.name],
|
||||||
|
}
|
||||||
|
|
||||||
|
# add schema for array
|
||||||
|
if param_type == "array":
|
||||||
|
# extract type of array, the int of list[int]
|
||||||
|
# use re
|
||||||
|
array_type_tuple = re.findall(r'list\[(\w+)\]', str(param.annotation))
|
||||||
|
|
||||||
|
array_type = 'string'
|
||||||
|
|
||||||
|
if len(array_type_tuple) > 0:
|
||||||
|
array_type = array_type_tuple[0]
|
||||||
|
|
||||||
|
if array_type in type_name_mapping:
|
||||||
|
array_type = type_name_mapping[array_type]
|
||||||
|
|
||||||
|
parameters['properties'][param.name]["items"] = {
|
||||||
|
"type": array_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
if param.default is inspect.Parameter.empty:
|
||||||
|
parameters["required"].append(param.name)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"function": function,
|
||||||
|
"description": desc,
|
||||||
|
"parameters": parameters,
|
||||||
|
}
|
|
@ -7,7 +7,6 @@ aiocqhttp
|
||||||
qq-botpy
|
qq-botpy
|
||||||
nakuru-project-idk
|
nakuru-project-idk
|
||||||
Pillow
|
Pillow
|
||||||
CallingGPT
|
|
||||||
tiktoken
|
tiktoken
|
||||||
PyYaml
|
PyYaml
|
||||||
aiohttp
|
aiohttp
|
||||||
|
|
Loading…
Reference in New Issue
Block a user