mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
feat: 异步风格插件方法注册器
This commit is contained in:
parent
fa823de6b0
commit
52a7c25540
|
@ -10,7 +10,6 @@ required_deps = {
|
|||
"botpy": "qq-botpy",
|
||||
"PIL": "pillow",
|
||||
"nakuru": "nakuru-project-idk",
|
||||
"CallingGPT": "CallingGPT",
|
||||
"tiktoken": "tiktoken",
|
||||
"yaml": "pyyaml",
|
||||
"aiohttp": "aiohttp",
|
||||
|
|
|
@ -13,6 +13,17 @@ class BasePlugin(metaclass=abc.ABCMeta):
|
|||
"""插件基类"""
|
||||
|
||||
host: APIHost
|
||||
"""API宿主"""
|
||||
|
||||
ap: app.Application
|
||||
"""应用程序对象"""
|
||||
|
||||
def __init__(self, host: APIHost):
|
||||
self.host = host
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化插件"""
|
||||
pass
|
||||
|
||||
|
||||
class APIHost:
|
||||
|
@ -61,8 +72,10 @@ class EventContext:
|
|||
"""事件编号"""
|
||||
|
||||
host: APIHost = None
|
||||
"""API宿主"""
|
||||
|
||||
event: events.BaseEventModel = None
|
||||
"""此次事件的对象,具体类型为handler注册时指定监听的类型,可查看events.py中的定义"""
|
||||
|
||||
__prevent_default__ = False
|
||||
"""是否阻止默认行为"""
|
||||
|
|
|
@ -10,8 +10,10 @@ from ..provider import entities as llm_entities
|
|||
|
||||
|
||||
class BaseEventModel(pydantic.BaseModel):
|
||||
"""事件模型基类"""
|
||||
|
||||
query: typing.Union[core_entities.Query, None]
|
||||
"""此次请求的query对象,可能为None"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
|
|
@ -5,11 +5,10 @@ import pkgutil
|
|||
import importlib
|
||||
import traceback
|
||||
|
||||
from CallingGPT.entities.namespace import get_func_schema
|
||||
|
||||
from .. import loader, events, context, models, host
|
||||
from ...core import entities as core_entities
|
||||
from ...provider.tools import entities as tools_entities
|
||||
from ...utils import funcschema
|
||||
|
||||
|
||||
class PluginLoader(loader.PluginLoader):
|
||||
|
@ -29,6 +28,9 @@ class PluginLoader(loader.PluginLoader):
|
|||
setattr(models, 'on', self.on)
|
||||
setattr(models, 'func', self.func)
|
||||
|
||||
setattr(models, 'handler', self.handler)
|
||||
setattr(models, 'llm_func', self.llm_func)
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
|
@ -57,6 +59,8 @@ class PluginLoader(loader.PluginLoader):
|
|||
|
||||
return wrapper
|
||||
|
||||
# 过时
|
||||
# 最早将于 v3.4 版本移除
|
||||
def on(
|
||||
self,
|
||||
event: typing.Type[events.BaseEventModel]
|
||||
|
@ -83,6 +87,8 @@ class PluginLoader(loader.PluginLoader):
|
|||
|
||||
return wrapper
|
||||
|
||||
# 过时
|
||||
# 最早将于 v3.4 版本移除
|
||||
def func(
|
||||
self,
|
||||
name: str=None,
|
||||
|
@ -91,10 +97,11 @@ class PluginLoader(loader.PluginLoader):
|
|||
self.ap.logger.debug(f'注册内容函数 {name}')
|
||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||
|
||||
function_schema = get_func_schema(func)
|
||||
function_schema = funcschema.get_func_schema(func)
|
||||
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
|
||||
|
||||
async def handler(
|
||||
plugin: context.BasePlugin,
|
||||
query: core_entities.Query,
|
||||
*args,
|
||||
**kwargs
|
||||
|
@ -116,6 +123,46 @@ class PluginLoader(loader.PluginLoader):
|
|||
|
||||
return wrapper
|
||||
|
||||
def handler(
|
||||
self,
|
||||
event: typing.Type[events.BaseEventModel]
|
||||
) -> typing.Callable[[typing.Callable], typing.Callable]:
|
||||
"""注册事件处理器"""
|
||||
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
|
||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||
|
||||
self._current_container.event_handlers[event] = func
|
||||
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
def llm_func(
|
||||
self,
|
||||
name: str=None,
|
||||
) -> typing.Callable:
|
||||
"""注册内容函数"""
|
||||
self.ap.logger.debug(f'注册内容函数 {name}')
|
||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||
|
||||
function_schema = funcschema.get_func_schema(func)
|
||||
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
|
||||
|
||||
llm_function = tools_entities.LLMFunction(
|
||||
name=function_name,
|
||||
human_desc='',
|
||||
description=function_schema['description'],
|
||||
enable=True,
|
||||
parameters=function_schema['parameters'],
|
||||
func=func,
|
||||
)
|
||||
|
||||
self._current_container.content_functions.append(llm_function)
|
||||
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
async def _walk_plugin_path(
|
||||
self,
|
||||
module,
|
|
@ -5,7 +5,7 @@ import traceback
|
|||
|
||||
from ..core import app
|
||||
from . import context, loader, events, installer, setting, models
|
||||
from .loaders import legacy
|
||||
from .loaders import classic
|
||||
from .installers import github
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ class PluginManager:
|
|||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.loader = legacy.PluginLoader(ap)
|
||||
self.loader = classic.PluginLoader(ap)
|
||||
self.installer = github.GitHubRepoInstaller(ap)
|
||||
self.setting = setting.SettingManager(ap)
|
||||
self.api_host = context.APIHost(ap)
|
||||
|
@ -52,6 +52,9 @@ class PluginManager:
|
|||
for plugin in self.plugins:
|
||||
try:
|
||||
plugin.plugin_inst = plugin.plugin_class(self.api_host)
|
||||
plugin.plugin_inst.ap = self.ap
|
||||
plugin.plugin_inst.host = self.api_host
|
||||
await plugin.plugin_inst.initialize()
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
|
||||
self.ap.logger.exception(e)
|
||||
|
|
|
@ -24,3 +24,15 @@ def func(
|
|||
name: str=None,
|
||||
) -> typing.Callable:
|
||||
pass
|
||||
|
||||
|
||||
def handler(
|
||||
event: typing.Type[BaseEventModel]
|
||||
) -> typing.Callable[[typing.Callable], typing.Callable]:
|
||||
pass
|
||||
|
||||
|
||||
def llm_func(
|
||||
name: str=None,
|
||||
) -> typing.Callable:
|
||||
pass
|
|
@ -5,6 +5,7 @@ import traceback
|
|||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
from ...plugin import context as plugin_context
|
||||
|
||||
|
||||
class ToolManager:
|
||||
|
@ -28,6 +29,15 @@ class ToolManager:
|
|||
return function
|
||||
return None
|
||||
|
||||
async def get_function_and_plugin(self, name: str) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]:
|
||||
"""获取函数和插件
|
||||
"""
|
||||
for plugin in self.ap.plugin_mgr.plugins:
|
||||
for function in plugin.content_functions:
|
||||
if function.name == name:
|
||||
return function, plugin
|
||||
return None, None
|
||||
|
||||
async def get_all_functions(self) -> list[entities.LLMFunction]:
|
||||
"""获取所有函数
|
||||
"""
|
||||
|
@ -68,7 +78,7 @@ class ToolManager:
|
|||
|
||||
try:
|
||||
|
||||
function = await self.get_function(name)
|
||||
function, plugin = await self.get_function_and_plugin(name)
|
||||
if function is None:
|
||||
return None
|
||||
|
||||
|
@ -79,7 +89,7 @@ class ToolManager:
|
|||
**parameters
|
||||
}
|
||||
|
||||
return await function.func(**parameters)
|
||||
return await function.func(plugin, **parameters)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
|
||||
traceback.print_exc()
|
||||
|
|
116
pkg/utils/funcschema.py
Normal file
116
pkg/utils/funcschema.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
import sys
|
||||
import re
|
||||
import inspect
|
||||
|
||||
|
||||
def get_func_schema(function: callable) -> dict:
|
||||
"""
|
||||
Return the data schema of a function.
|
||||
{
|
||||
"function": function,
|
||||
"description": "function description",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"parameter_a": {
|
||||
"type": "str",
|
||||
"description": "parameter_a description"
|
||||
},
|
||||
"parameter_b": {
|
||||
"type": "int",
|
||||
"description": "parameter_b description"
|
||||
},
|
||||
"parameter_c": {
|
||||
"type": "str",
|
||||
"description": "parameter_c description",
|
||||
"enum": ["a", "b", "c"]
|
||||
},
|
||||
},
|
||||
"required": ["parameter_a", "parameter_b"]
|
||||
}
|
||||
}
|
||||
"""
|
||||
func_doc = function.__doc__
|
||||
# Google Style Docstring
|
||||
if func_doc is None:
|
||||
raise Exception("Function {} has no docstring.".format(function.__name__))
|
||||
func_doc = func_doc.strip().replace(' ','').replace('\t', '')
|
||||
# extract doc of args from docstring
|
||||
doc_spt = func_doc.split('\n\n')
|
||||
desc = doc_spt[0]
|
||||
args = doc_spt[1] if len(doc_spt) > 1 else ""
|
||||
returns = doc_spt[2] if len(doc_spt) > 2 else ""
|
||||
|
||||
# extract args
|
||||
# delete the first line of args
|
||||
arg_lines = args.split('\n')[1:]
|
||||
arg_doc_list = re.findall(r'(\w+)(\((\w+)\))?:\s*(.*)', args)
|
||||
args_doc = {}
|
||||
for arg_line in arg_lines:
|
||||
doc_tuple = re.findall(r'(\w+)(\(([\w\[\]]+)\))?:\s*(.*)', arg_line)
|
||||
if len(doc_tuple) == 0:
|
||||
continue
|
||||
args_doc[doc_tuple[0][0]] = doc_tuple[0][3]
|
||||
|
||||
# extract returns
|
||||
return_doc_list = re.findall(r'(\w+):\s*(.*)', returns)
|
||||
|
||||
params = enumerate(inspect.signature(function).parameters.values())
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"required": [],
|
||||
"properties": {},
|
||||
}
|
||||
|
||||
|
||||
for i, param in params:
|
||||
|
||||
# 排除 self, query
|
||||
if param.name in ['self', 'query']:
|
||||
continue
|
||||
|
||||
param_type = param.annotation.__name__
|
||||
|
||||
type_name_mapping = {
|
||||
"str": "string",
|
||||
"int": "integer",
|
||||
"float": "number",
|
||||
"bool": "boolean",
|
||||
"list": "array",
|
||||
"dict": "object",
|
||||
}
|
||||
|
||||
if param_type in type_name_mapping:
|
||||
param_type = type_name_mapping[param_type]
|
||||
|
||||
parameters['properties'][param.name] = {
|
||||
"type": param_type,
|
||||
"description": args_doc[param.name],
|
||||
}
|
||||
|
||||
# add schema for array
|
||||
if param_type == "array":
|
||||
# extract type of array, the int of list[int]
|
||||
# use re
|
||||
array_type_tuple = re.findall(r'list\[(\w+)\]', str(param.annotation))
|
||||
|
||||
array_type = 'string'
|
||||
|
||||
if len(array_type_tuple) > 0:
|
||||
array_type = array_type_tuple[0]
|
||||
|
||||
if array_type in type_name_mapping:
|
||||
array_type = type_name_mapping[array_type]
|
||||
|
||||
parameters['properties'][param.name]["items"] = {
|
||||
"type": array_type,
|
||||
}
|
||||
|
||||
if param.default is inspect.Parameter.empty:
|
||||
parameters["required"].append(param.name)
|
||||
|
||||
return {
|
||||
"function": function,
|
||||
"description": desc,
|
||||
"parameters": parameters,
|
||||
}
|
|
@ -7,7 +7,6 @@ aiocqhttp
|
|||
qq-botpy
|
||||
nakuru-project-idk
|
||||
Pillow
|
||||
CallingGPT
|
||||
tiktoken
|
||||
PyYaml
|
||||
aiohttp
|
||||
|
|
Loading…
Reference in New Issue
Block a user