feat: 异步风格插件方法注册器

This commit is contained in:
RockChinQ 2024-03-20 15:09:47 +08:00
parent fa823de6b0
commit 52a7c25540
9 changed files with 210 additions and 9 deletions

View File

@ -10,7 +10,6 @@ required_deps = {
"botpy": "qq-botpy",
"PIL": "pillow",
"nakuru": "nakuru-project-idk",
"CallingGPT": "CallingGPT",
"tiktoken": "tiktoken",
"yaml": "pyyaml",
"aiohttp": "aiohttp",

View File

@ -13,6 +13,17 @@ class BasePlugin(metaclass=abc.ABCMeta):
"""插件基类"""
host: APIHost
"""API宿主"""
ap: app.Application
"""应用程序对象"""
def __init__(self, host: APIHost):
self.host = host
async def initialize(self):
"""初始化插件"""
pass
class APIHost:
@ -61,8 +72,10 @@ class EventContext:
"""事件编号"""
host: APIHost = None
"""API宿主"""
event: events.BaseEventModel = None
"""此次事件的对象具体类型为handler注册时指定监听的类型可查看events.py中的定义"""
__prevent_default__ = False
"""是否阻止默认行为"""

View File

@ -10,8 +10,10 @@ from ..provider import entities as llm_entities
class BaseEventModel(pydantic.BaseModel):
"""事件模型基类"""
query: typing.Union[core_entities.Query, None]
"""此次请求的query对象可能为None"""
class Config:
arbitrary_types_allowed = True

View File

@ -5,11 +5,10 @@ import pkgutil
import importlib
import traceback
from CallingGPT.entities.namespace import get_func_schema
from .. import loader, events, context, models, host
from ...core import entities as core_entities
from ...provider.tools import entities as tools_entities
from ...utils import funcschema
class PluginLoader(loader.PluginLoader):
@ -29,6 +28,9 @@ class PluginLoader(loader.PluginLoader):
setattr(models, 'on', self.on)
setattr(models, 'func', self.func)
setattr(models, 'handler', self.handler)
setattr(models, 'llm_func', self.llm_func)
def register(
self,
name: str,
@ -57,6 +59,8 @@ class PluginLoader(loader.PluginLoader):
return wrapper
# 过时
# 最早将于 v3.4 版本移除
def on(
self,
event: typing.Type[events.BaseEventModel]
@ -83,6 +87,8 @@ class PluginLoader(loader.PluginLoader):
return wrapper
# 过时
# 最早将于 v3.4 版本移除
def func(
self,
name: str=None,
@ -91,10 +97,11 @@ class PluginLoader(loader.PluginLoader):
self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable:
function_schema = get_func_schema(func)
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
async def handler(
plugin: context.BasePlugin,
query: core_entities.Query,
*args,
**kwargs
@ -116,6 +123,46 @@ class PluginLoader(loader.PluginLoader):
return wrapper
def handler(
self,
event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册事件处理器"""
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
def wrapper(func: typing.Callable) -> typing.Callable:
self._current_container.event_handlers[event] = func
return func
return wrapper
def llm_func(
self,
name: str=None,
) -> typing.Callable:
"""注册内容函数"""
self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable:
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
llm_function = tools_entities.LLMFunction(
name=function_name,
human_desc='',
description=function_schema['description'],
enable=True,
parameters=function_schema['parameters'],
func=func,
)
self._current_container.content_functions.append(llm_function)
return func
return wrapper
async def _walk_plugin_path(
self,
module,

View File

@ -5,7 +5,7 @@ import traceback
from ..core import app
from . import context, loader, events, installer, setting, models
from .loaders import legacy
from .loaders import classic
from .installers import github
@ -26,7 +26,7 @@ class PluginManager:
def __init__(self, ap: app.Application):
self.ap = ap
self.loader = legacy.PluginLoader(ap)
self.loader = classic.PluginLoader(ap)
self.installer = github.GitHubRepoInstaller(ap)
self.setting = setting.SettingManager(ap)
self.api_host = context.APIHost(ap)
@ -52,6 +52,9 @@ class PluginManager:
for plugin in self.plugins:
try:
plugin.plugin_inst = plugin.plugin_class(self.api_host)
plugin.plugin_inst.ap = self.ap
plugin.plugin_inst.host = self.api_host
await plugin.plugin_inst.initialize()
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
self.ap.logger.exception(e)

View File

@ -24,3 +24,15 @@ def func(
name: str=None,
) -> typing.Callable:
pass
def handler(
event: typing.Type[BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
pass
def llm_func(
name: str=None,
) -> typing.Callable:
pass

View File

@ -5,6 +5,7 @@ import traceback
from ...core import app, entities as core_entities
from . import entities
from ...plugin import context as plugin_context
class ToolManager:
@ -28,6 +29,15 @@ class ToolManager:
return function
return None
async def get_function_and_plugin(self, name: str) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]:
"""获取函数和插件
"""
for plugin in self.ap.plugin_mgr.plugins:
for function in plugin.content_functions:
if function.name == name:
return function, plugin
return None, None
async def get_all_functions(self) -> list[entities.LLMFunction]:
"""获取所有函数
"""
@ -68,7 +78,7 @@ class ToolManager:
try:
function = await self.get_function(name)
function, plugin = await self.get_function_and_plugin(name)
if function is None:
return None
@ -79,7 +89,7 @@ class ToolManager:
**parameters
}
return await function.func(**parameters)
return await function.func(plugin, **parameters)
except Exception as e:
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
traceback.print_exc()

116
pkg/utils/funcschema.py Normal file
View File

@ -0,0 +1,116 @@
import sys
import re
import inspect
def get_func_schema(function: callable) -> dict:
"""
Return the data schema of a function.
{
"function": function,
"description": "function description",
"parameters": {
"type": "object",
"properties": {
"parameter_a": {
"type": "str",
"description": "parameter_a description"
},
"parameter_b": {
"type": "int",
"description": "parameter_b description"
},
"parameter_c": {
"type": "str",
"description": "parameter_c description",
"enum": ["a", "b", "c"]
},
},
"required": ["parameter_a", "parameter_b"]
}
}
"""
func_doc = function.__doc__
# Google Style Docstring
if func_doc is None:
raise Exception("Function {} has no docstring.".format(function.__name__))
func_doc = func_doc.strip().replace(' ','').replace('\t', '')
# extract doc of args from docstring
doc_spt = func_doc.split('\n\n')
desc = doc_spt[0]
args = doc_spt[1] if len(doc_spt) > 1 else ""
returns = doc_spt[2] if len(doc_spt) > 2 else ""
# extract args
# delete the first line of args
arg_lines = args.split('\n')[1:]
arg_doc_list = re.findall(r'(\w+)(\((\w+)\))?:\s*(.*)', args)
args_doc = {}
for arg_line in arg_lines:
doc_tuple = re.findall(r'(\w+)(\(([\w\[\]]+)\))?:\s*(.*)', arg_line)
if len(doc_tuple) == 0:
continue
args_doc[doc_tuple[0][0]] = doc_tuple[0][3]
# extract returns
return_doc_list = re.findall(r'(\w+):\s*(.*)', returns)
params = enumerate(inspect.signature(function).parameters.values())
parameters = {
"type": "object",
"required": [],
"properties": {},
}
for i, param in params:
# 排除 self, query
if param.name in ['self', 'query']:
continue
param_type = param.annotation.__name__
type_name_mapping = {
"str": "string",
"int": "integer",
"float": "number",
"bool": "boolean",
"list": "array",
"dict": "object",
}
if param_type in type_name_mapping:
param_type = type_name_mapping[param_type]
parameters['properties'][param.name] = {
"type": param_type,
"description": args_doc[param.name],
}
# add schema for array
if param_type == "array":
# extract type of array, the int of list[int]
# use re
array_type_tuple = re.findall(r'list\[(\w+)\]', str(param.annotation))
array_type = 'string'
if len(array_type_tuple) > 0:
array_type = array_type_tuple[0]
if array_type in type_name_mapping:
array_type = type_name_mapping[array_type]
parameters['properties'][param.name]["items"] = {
"type": array_type,
}
if param.default is inspect.Parameter.empty:
parameters["required"].append(param.name)
return {
"function": function,
"description": desc,
"parameters": parameters,
}

View File

@ -7,7 +7,6 @@ aiocqhttp
qq-botpy
nakuru-project-idk
Pillow
CallingGPT
tiktoken
PyYaml
aiohttp