refactor: 恢复插件事件调用

This commit is contained in:
RockChinQ 2024-01-30 21:45:17 +08:00
parent e2de3d0102
commit 33d600fb6b
10 changed files with 298 additions and 86 deletions

View File

@ -19,7 +19,7 @@ from ..utils import version as version_mgr, proxy as proxy_mgr
class Application:
im_mgr: im_mgr.QQBotManager = None
im_mgr: im_mgr.PlatformManager = None
cmd_mgr: cmdmgr.CommandManager = None

View File

@ -88,10 +88,13 @@ async def make_app() -> app.Application:
# 发送公告
ann_mgr = announce.AnnouncementManager(ap)
announcements = await ann_mgr.fetch_new()
try:
announcements = await ann_mgr.fetch_new()
for ann in announcements:
ap.logger.info(f'[公告] {ann.time}: {ann.content}')
for ann in announcements:
ap.logger.info(f'[公告] {ann.time}: {ann.content}')
except Exception as e:
ap.logger.warning(f'获取公告时出错: {e}')
ap.query_pool = pool.QueryPool()
@ -99,8 +102,13 @@ async def make_app() -> app.Application:
await ver_mgr.initialize()
ap.ver_mgr = ver_mgr
if await ap.ver_mgr.is_new_version_available():
ap.logger.info("有新版本可用,请使用 !update 命令更新")
try:
if await ap.ver_mgr.is_new_version_available():
ap.logger.info("有新版本可用,请使用 !update 命令更新")
except Exception as e:
ap.logger.warning(f"检查版本更新时出错: {e}")
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize()
@ -141,7 +149,7 @@ async def make_app() -> app.Application:
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst
im_mgr_inst = im_mgr.QQBotManager(ap=ap)
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize()
ap.im_mgr = im_mgr_inst

View File

@ -6,6 +6,7 @@ import traceback
from . import app, entities
from ..pipeline import entities as pipeline_entities
from ..plugin import events
DEFAULT_QUERY_CONCURRENCY = 10

View File

@ -8,6 +8,7 @@ from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import entities as llm_entities
from ....plugin import events
class ChatMessageHandler(handler.MessageHandler):
@ -22,23 +23,152 @@ class ChatMessageHandler(handler.MessageHandler):
# 取conversation
# 调API
# 生成器
session = await self.ap.sess_mgr.get_session(query)
conversation = await self.ap.sess_mgr.get_conversation(session)
# 触发插件事件
event_class = events.PersonNormalMessageReceived if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupNormalMessageReceived
conversation.messages.append(
llm_entities.Message(
role="user",
content=str(query.message_chain)
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_class(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
text_message=str(query.message_chain),
query=query
)
)
async for result in conversation.use_model.requester.request(query, conversation):
conversation.messages.append(result)
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
query.resp_message_chain = mirai.MessageChain([mirai.Plain(str(result))])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
if event_ctx.event.alter is not None:
query.message_chain = mirai.MessageChain([
mirai.Plain(event_ctx.event.alter)
])
session = await self.ap.sess_mgr.get_session(query)
conversation = await self.ap.sess_mgr.get_conversation(session)
# =========== 触发事件 PromptPreProcessing
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PromptPreProcessing(
session_name=f'{session.launcher_type.value}_{session.launcher_id}',
default_prompt=conversation.prompt.messages,
prompt=conversation.messages,
query=query
)
)
conversation.prompt.messages = event_ctx.event.default_prompt
conversation.messages = event_ctx.event.prompt
conversation.messages.append(
llm_entities.Message(
role="user",
content=str(query.message_chain)
)
)
called_functions = []
async for result in conversation.use_model.requester.request(query, conversation):
conversation.messages.append(result)
# 转换成可读消息
if result.role == 'assistant':
reply_text = ''
if result.content is not None: # 有内容
reply_text = result.content
# ============= 触发插件事件 ===============
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=called_functions,
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
if result.tool_calls is not None: # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
called_functions.extend(function_names)
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
if self.ap.cfg_mgr.data['trace_function_calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=called_functions,
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@ -6,6 +6,7 @@ import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....plugin import events
class CommandHandler(handler.MessageHandler):
@ -16,28 +17,34 @@ class CommandHandler(handler.MessageHandler):
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
session = await self.ap.sess_mgr.get_session(query)
command_text = str(query.message_chain).strip()[1:]
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent
async for ret in self.ap.cmd_mgr.execute(
command_text=command_text,
query=query,
session=session
):
if ret.error is not None:
query.resp_message_chain = mirai.MessageChain([
mirai.Plain(str(ret.error))
])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif ret.text is not None:
query.resp_message_chain = mirai.MessageChain([
mirai.Plain(ret.text)
])
privilege = 1
if query.sender_id == self.ap.cfg_mgr.data['admin_qq'] \
or query.sender_id in self.ap.cfg_mgr['admin_qq']:
privilege = 2
spt = str(query.message_chain).strip().split(' ')
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_class(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
command=spt[0],
params=spt[1:] if len(spt) > 1 else [],
text_message=str(query.message_chain),
is_admin=(privilege==2),
query=query
)
)
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
@ -48,3 +55,43 @@ class CommandHandler(handler.MessageHandler):
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.alter is not None:
query.message_chain = mirai.MessageChain([
mirai.Plain(event_ctx.event.alter)
])
session = await self.ap.sess_mgr.get_session(query)
command_text = str(query.message_chain).strip()[1:]
async for ret in self.ap.cmd_mgr.execute(
command_text=command_text,
query=query,
session=session
):
if ret.error is not None:
query.resp_message_chain = mirai.MessageChain([
mirai.Plain(str(ret.error))
])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif ret.text is not None:
query.resp_message_chain = mirai.MessageChain([
mirai.Plain(ret.text)
])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)

View File

@ -18,10 +18,11 @@ from ..platform import adapter as msadapter
from .ratelim import ratelim
from ..core import app, entities as core_entities
from ..plugin import events
# 控制QQ消息输入输出的类
class QQBotManager:
class PlatformManager:
adapter: msadapter.MessageSourceAdapter = None
@ -60,14 +61,26 @@ class QQBotManager:
async def on_friend_message(event: FriendMessage):
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PersonMessageReceived(
launcher_type='person',
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_chain=event.message_chain,
query=None
)
)
if not event_ctx.is_prevented_default():
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
)
self.adapter.register_listener(
FriendMessage,
on_friend_message
@ -75,14 +88,26 @@ class QQBotManager:
async def on_stranger_message(event: StrangerMessage):
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PersonMessageReceived(
launcher_type='person',
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_chain=event.message_chain,
query=None
)
)
if not event_ctx.is_prevented_default():
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
)
# nakuru不区分好友和陌生人故仅为yirimirai注册陌生人事件
if config['msg_source_adapter'] == 'yirimirai':
self.adapter.register_listener(
@ -92,14 +117,26 @@ class QQBotManager:
async def on_group_message(event: GroupMessage):
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.GROUP,
launcher_id=event.group.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.GroupMessageReceived(
launcher_type='person',
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_chain=event.message_chain,
query=None
)
)
if not event_ctx.is_prevented_default():
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.GROUP,
launcher_id=event.group.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
)
self.adapter.register_listener(
GroupMessage,
on_group_message

View File

@ -5,14 +5,13 @@ import typing
import pydantic
import mirai
from . import context
from ..core import entities as core_entities
from ..provider import entities as llm_entities
class BaseEventModel(pydantic.BaseModel):
query: core_entities.Query
query: core_entities.Query | None
class Config:
arbitrary_types_allowed = True
@ -142,7 +141,7 @@ class NormalMessageResponded(BaseEventModel):
"""会话对象"""
prefix: str
"""回复消息的前缀,可修改"""
"""回复消息的前缀"""
response_text: str
"""回复消息的文本"""
@ -157,24 +156,6 @@ class NormalMessageResponded(BaseEventModel):
"""回复消息组件列表"""
class SessionExplicitReset(BaseEventModel):
"""会话被显式重置时触发"""
session_name: str
session: core_entities.Session
class SessionExpired(BaseEventModel):
"""会话过期时触发"""
session_name: str
session: core_entities.Session
session_expire_time: int
class PromptPreProcessing(BaseEventModel):
"""会话中的Prompt预处理时触发"""
@ -185,6 +166,3 @@ class PromptPreProcessing(BaseEventModel):
prompt: list[llm_entities.Message]
"""此对话现有消息记录,可修改"""
text_message: str
"""消息文本,可修改"""

View File

@ -75,7 +75,7 @@ class PluginLoader(loader.PluginLoader):
for k, v in ctx.event.dict().items():
args[k] = v
await func(plugin, **args)
func(plugin, **args)
self._current_container.event_handlers[event] = handler

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import typing
import traceback
from ..core import app
from . import context, loader, events, installer, setting, models
@ -100,6 +101,9 @@ class PluginManager:
for plugin in self.plugins:
if plugin.enabled:
if event.__class__ in plugin.event_handlers:
is_prevented_default_before_call = ctx.is_prevented_default()
try:
await plugin.event_handlers[event.__class__](
plugin.plugin_inst,
@ -107,12 +111,19 @@ class PluginManager:
)
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__} 时发生错误: {e}')
self.ap.logger.exception(e)
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
if not is_prevented_default_before_call and ctx.is_prevented_default():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行')
if ctx.is_prevented_postorder():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行')
break
for key in ctx.__return_value__.keys():
if hasattr(ctx.event, key):
setattr(ctx.event, key, ctx.__return_value__[key][0])
self.ap.logger.debug(f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}')
return ctx
return ctx

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import typing
from .context import BasePlugin as Plugin
from . import events
from .events import *
def register(
name: str,
@ -15,7 +15,7 @@ def register(
def on(
event: typing.Type[events.BaseEventModel]
event: typing.Type[BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
pass