mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 19:57:04 +08:00
156 lines
4.9 KiB
Python
156 lines
4.9 KiB
Python
from __future__ import annotations
|
|
|
|
import typing
|
|
import pkgutil
|
|
import importlib
|
|
import traceback
|
|
|
|
from CallingGPT.entities.namespace import get_func_schema
|
|
|
|
from .. import loader, events, context, models, host
|
|
from ...core import entities as core_entities
|
|
from ...provider.tools import entities as tools_entities
|
|
|
|
|
|
class PluginLoader(loader.PluginLoader):
    """Load legacy plugins from the ``plugins/`` directory.

    ``initialize`` binds ``models.register``, ``models.on`` and ``models.func``
    to this loader's methods. ``load_plugins`` then imports every module under
    the ``plugins`` package; a module that calls ``register`` while it is being
    imported gets a ``RuntimeContainer``, which ``on``/``func`` calls fill in
    and which is appended to ``containers`` once the import succeeds.
    """

    # Package directory of the module currently being imported, e.g. "plugins/foo/".
    _current_pkg_path = ''

    # File path of the module currently being imported, e.g. "plugins/foo/main.py".
    _current_module_path = ''

    # Container created by the latest ``register`` call during the current module
    # import; reset to None before each module is imported.
    _current_container: typing.Optional[context.RuntimeContainer] = None

    # Containers of plugin modules that called ``register`` and imported without
    # error. Assigned per instance in ``__init__``: a class-level ``[]`` default
    # would be a single list shared by every PluginLoader instance.
    containers: list[context.RuntimeContainer]

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base loader, then create this instance's container list."""
        super().__init__(*args, **kwargs)
        self.containers = []

    async def initialize(self):
        """Bind the legacy ``models.register/on/func`` entry points to this loader."""
        setattr(models, 'register', self.register)
        setattr(models, 'on', self.on)
        setattr(models, 'func', self.func)

    def register(
        self,
        name: str,
        description: str,
        version: str,
        author: str
    ) -> typing.Callable[[typing.Type[context.BasePlugin]], typing.Type[context.BasePlugin]]:
        """Declare a legacy plugin and return a decorator for its class.

        A new ``RuntimeContainer`` is created immediately and becomes
        ``_current_container``, so any ``on``/``func`` calls made afterwards
        (during the same module import) attach to it.

        Args:
            name: plugin name.
            description: plugin description.
            version: plugin version string.
            author: plugin author.

        Returns:
            A class decorator that stores the class as ``container.plugin_class``
            and returns the class unchanged.
        """
        self.ap.logger.debug(f'注册插件 {name} {version} by {author}')

        container = context.RuntimeContainer(
            plugin_name=name,
            plugin_description=description,
            plugin_version=version,
            plugin_author=author,
            plugin_source='',
            pkg_path=self._current_pkg_path,
            main_file=self._current_module_path,
            event_handlers={},
            content_functions=[],
        )

        self._current_container = container

        def wrapper(cls: typing.Type[context.BasePlugin]) -> typing.Type[context.BasePlugin]:
            container.plugin_class = cls
            return cls

        return wrapper

    def on(
        self,
        event: typing.Type[events.BaseEventModel]
    ) -> typing.Callable[[typing.Callable], typing.Callable]:
        """Register a deprecated-style (legacy) event handler on the current plugin.

        The decorated function is returned unchanged. An async adapter is stored
        in ``_current_container.event_handlers[event]``; when invoked it calls the
        legacy function synchronously as ``func(plugin, host=..., event=..., **fields)``
        where ``fields`` are the items of ``ctx.event.dict()``.

        Args:
            event: event type the handler is registered for.

        Returns:
            A function decorator.
        """
        self.ap.logger.debug(f'注册事件处理器 {event.__name__}')

        def wrapper(func: typing.Callable) -> typing.Callable:

            async def handler(plugin: context.BasePlugin, ctx: context.EventContext) -> None:
                # Event fields are merged last, so a field named 'host' or 'event'
                # overrides the defaults above.
                # NOTE(review): if events are pydantic models, .dict() converts nested
                # models to plain dicts recursively — confirm legacy handlers expect that.
                args = {
                    'host': ctx.host,
                    'event': ctx,
                    **ctx.event.dict(),
                }

                func(plugin, **args)

            # NOTE(review): keyed by event type, so a second handler for the same
            # event in one plugin replaces the first — confirm that is intended.
            self._current_container.event_handlers[event] = handler

            return func

        return wrapper

    def func(
        self,
        name: typing.Optional[str] = None,
    ) -> typing.Callable:
        """Register a deprecated-style (legacy) content function on the current plugin.

        The function's schema comes from ``get_func_schema``; it is registered as an
        ``LLMFunction`` named ``"<plugin_name>-<name or func.__name__>"``. The stored
        async adapter drops its leading ``query`` argument and calls the legacy
        function synchronously with the remaining arguments.

        Args:
            name: optional name override; defaults to the function's ``__name__``.

        Returns:
            A function decorator that returns the decorated function unchanged.
        """
        self.ap.logger.debug(f'注册内容函数 {name}')

        def wrapper(func: typing.Callable) -> typing.Callable:

            function_schema = get_func_schema(func)
            local_name = func.__name__ if name is None else name
            function_name = self._current_container.plugin_name + '-' + local_name

            async def handler(
                query: core_entities.Query,
                *args,
                **kwargs
            ):
                return func(*args, **kwargs)

            llm_function = tools_entities.LLMFunction(
                name=function_name,
                human_desc='',
                description=function_schema['description'],
                enable=True,
                parameters=function_schema['parameters'],
                func=handler,
            )

            self._current_container.content_functions.append(llm_function)

            return func

        return wrapper

    async def _walk_plugin_path(
        self,
        module,
        prefix='',
        path_prefix=''
    ):
        """Recursively import every module under ``module`` and collect plugin containers.

        A module or sub-package that fails to import is logged (traceback printed
        via ``traceback.print_exc``) and skipped, so one broken plugin does not
        stop the others from loading.

        Args:
            module: package to scan (must expose ``__path__``).
            prefix: dotted prefix of ``module`` below ``plugins``, used in log messages.
            path_prefix: slash-separated prefix below ``plugins/``, used to build file paths.
        """
        for item in pkgutil.iter_modules(module.__path__):
            full_name = module.__name__ + "." + item.name

            if item.ispkg:
                # Guard the package import itself: an exception raised by a plugin
                # package's __init__ must not abort the whole walk.
                try:
                    sub_package = importlib.import_module(full_name)
                except Exception:
                    self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误')
                    traceback.print_exc()
                    continue

                await self._walk_plugin_path(
                    sub_package,
                    prefix + item.name + ".",
                    path_prefix + item.name + "/",
                )
                continue

            # Expose the module's location to ``register`` before importing it.
            self._current_pkg_path = "plugins/" + path_prefix
            self._current_module_path = "plugins/" + path_prefix + item.name + ".py"

            # Reset so a module that never calls ``register`` adds no container.
            self._current_container = None

            try:
                importlib.import_module(full_name)
            except Exception:
                self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误')
                traceback.print_exc()
                continue

            if self._current_container is not None:
                self.containers.append(self._current_container)
                self.ap.logger.debug(f'插件 {self._current_container} 已加载')

    async def load_plugins(self) -> list[context.RuntimeContainer]:
        """Import everything under the top-level ``plugins`` package.

        Returns:
            ``self.containers`` — every container this loader has collected.
        """
        await self._walk_plugin_path(importlib.import_module("plugins"))

        return self.containers
|