diff --git a/pkg/core/app.py b/pkg/core/app.py index 3ce86d3..cd9607b 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -19,7 +19,7 @@ from ..utils import version as version_mgr, proxy as proxy_mgr class Application: - im_mgr: im_mgr.QQBotManager = None + im_mgr: im_mgr.PlatformManager = None cmd_mgr: cmdmgr.CommandManager = None diff --git a/pkg/core/boot.py b/pkg/core/boot.py index fedb1e2..8faf743 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -88,10 +88,13 @@ async def make_app() -> app.Application: # 发送公告 ann_mgr = announce.AnnouncementManager(ap) - announcements = await ann_mgr.fetch_new() + try: + announcements = await ann_mgr.fetch_new() - for ann in announcements: - ap.logger.info(f'[公告] {ann.time}: {ann.content}') + for ann in announcements: + ap.logger.info(f'[公告] {ann.time}: {ann.content}') + except Exception as e: + ap.logger.warning(f'获取公告时出错: {e}') ap.query_pool = pool.QueryPool() @@ -99,8 +102,13 @@ async def make_app() -> app.Application: await ver_mgr.initialize() ap.ver_mgr = ver_mgr - if await ap.ver_mgr.is_new_version_available(): - ap.logger.info("有新版本可用,请使用 !update 命令更新") + try: + + if await ap.ver_mgr.is_new_version_available(): + ap.logger.info("有新版本可用,请使用 !update 命令更新") + + except Exception as e: + ap.logger.warning(f"检查版本更新时出错: {e}") plugin_mgr_inst = plugin_mgr.PluginManager(ap) await plugin_mgr_inst.initialize() @@ -141,7 +149,7 @@ async def make_app() -> app.Application: await llm_tool_mgr_inst.initialize() ap.tool_mgr = llm_tool_mgr_inst - im_mgr_inst = im_mgr.QQBotManager(ap=ap) + im_mgr_inst = im_mgr.PlatformManager(ap=ap) await im_mgr_inst.initialize() ap.im_mgr = im_mgr_inst diff --git a/pkg/core/controller.py b/pkg/core/controller.py index eb6241a..40f496a 100644 --- a/pkg/core/controller.py +++ b/pkg/core/controller.py @@ -6,6 +6,7 @@ import traceback from . import app, entities from ..pipeline import entities as pipeline_entities +from ..plugin import events DEFAULT_QUERY_CONCURRENCY = 10 diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 4c7b136..f726b57 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -8,6 +8,7 @@ from .. import handler from ... import entities from ....core import entities as core_entities from ....provider import entities as llm_entities +from ....plugin import events class ChatMessageHandler(handler.MessageHandler): @@ -22,23 +23,152 @@ class ChatMessageHandler(handler.MessageHandler): # 取conversation # 调API # 生成器 - session = await self.ap.sess_mgr.get_session(query) - conversation = await self.ap.sess_mgr.get_conversation(session) + # 触发插件事件 + event_class = events.PersonNormalMessageReceived if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupNormalMessageReceived - conversation.messages.append( - llm_entities.Message( - role="user", - content=str(query.message_chain) + event_ctx = await self.ap.plugin_mgr.emit_event( + event=event_class( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + text_message=str(query.message_chain), + query=query ) ) - async for result in conversation.use_model.requester.request(query, conversation): - conversation.messages.append(result) + if event_ctx.is_prevented_default(): + if event_ctx.event.reply is not None: + query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) - query.resp_message_chain = mirai.MessageChain([mirai.Plain(str(result))]) + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) + else: - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query + if event_ctx.event.alter is not None: + query.message_chain = mirai.MessageChain([ + mirai.Plain(event_ctx.event.alter) + ]) + + session = await self.ap.sess_mgr.get_session(query) + + conversation = await self.ap.sess_mgr.get_conversation(session) + + # =========== 触发事件 PromptPreProcessing + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.PromptPreProcessing( + session_name=f'{session.launcher_type.value}_{session.launcher_id}', + default_prompt=conversation.prompt.messages, + prompt=conversation.messages, + query=query + ) ) + + conversation.prompt.messages = event_ctx.event.default_prompt + conversation.messages = event_ctx.event.prompt + + conversation.messages.append( + llm_entities.Message( + role="user", + content=str(query.message_chain) + ) + ) + + called_functions = [] + + async for result in conversation.use_model.requester.request(query, conversation): + conversation.messages.append(result) + + # 转换成可读消息 + if result.role == 'assistant': + + reply_text = '' + + if result.content is not None: # 有内容 + reply_text = result.content + + # ============= 触发插件事件 =============== + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.NormalMessageResponded( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + session=session, + prefix='', + response_text=reply_text, + finish_reason='stop', + funcs_called=called_functions, + query=query + ) + ) + if event_ctx.is_prevented_default(): + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) + else: + if event_ctx.event.reply is not None: + + query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) + + else: + + query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + if result.tool_calls is not None: # 有函数调用 + + function_names = [tc.function.name for tc in result.tool_calls] + + reply_text = f'调用函数 {".".join(function_names)}...' + + called_functions.extend(function_names) + + query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) + + if self.ap.cfg_mgr.data['trace_function_calls']: + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.NormalMessageResponded( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + session=session, + prefix='', + response_text=reply_text, + finish_reason='stop', + funcs_called=called_functions, + query=query + ) + ) + + if event_ctx.is_prevented_default(): + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) + else: + if event_ctx.event.reply is not None: + + query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) + + else: + + query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index f836a2a..60543dc 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -6,6 +6,7 @@ import mirai from .. import handler from ... import entities from ....core import entities as core_entities +from ....plugin import events class CommandHandler(handler.MessageHandler): @@ -16,28 +17,34 @@ class CommandHandler(handler.MessageHandler): ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: """处理 """ - session = await self.ap.sess_mgr.get_session(query) - command_text = str(query.message_chain).strip()[1:] + event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent - async for ret in self.ap.cmd_mgr.execute( - command_text=command_text, - query=query, - session=session - ): - if ret.error is not None: - query.resp_message_chain = mirai.MessageChain([ - mirai.Plain(str(ret.error)) - ]) - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query - ) - elif ret.text is not None: - query.resp_message_chain = mirai.MessageChain([ - mirai.Plain(ret.text) - ]) + privilege = 1 + if query.sender_id == self.ap.cfg_mgr.data['admin_qq'] \ + or query.sender_id in self.ap.cfg_mgr['admin_qq']: + privilege = 2 + + spt = str(query.message_chain).strip().split(' ') + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=event_class( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + command=spt[0], + params=spt[1:] if len(spt) > 1 else [], + text_message=str(query.message_chain), + is_admin=(privilege==2), + query=query + ) + ) + + if event_ctx.is_prevented_default(): + + if event_ctx.event.reply is not None: + query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, @@ -48,3 +55,43 @@ class CommandHandler(handler.MessageHandler): result_type=entities.ResultType.INTERRUPT, new_query=query ) + + else: + + if event_ctx.event.alter is not None: + query.message_chain = mirai.MessageChain([ + mirai.Plain(event_ctx.event.alter) + ]) + + session = await self.ap.sess_mgr.get_session(query) + + command_text = str(query.message_chain).strip()[1:] + + async for ret in self.ap.cmd_mgr.execute( + command_text=command_text, + query=query, + session=session + ): + if ret.error is not None: + query.resp_message_chain = mirai.MessageChain([ + mirai.Plain(str(ret.error)) + ]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + elif ret.text is not None: + query.resp_message_chain = mirai.MessageChain([ + mirai.Plain(ret.text) + ]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 7e8f581..741d3a2 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -18,10 +18,11 @@ from ..platform import adapter as msadapter from .ratelim import ratelim from ..core import app, entities as core_entities +from ..plugin import events # 控制QQ消息输入输出的类 -class QQBotManager: +class PlatformManager: adapter: msadapter.MessageSourceAdapter = None @@ -60,14 +61,26 @@ class QQBotManager: async def on_friend_message(event: FriendMessage): - await self.ap.query_pool.add_query( - launcher_type=core_entities.LauncherTypes.PERSON, - launcher_id=event.sender.id, - sender_id=event.sender.id, - message_event=event, - message_chain=event.message_chain + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.PersonMessageReceived( + launcher_type='person', + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_chain=event.message_chain, + query=None + ) ) + if not event_ctx.is_prevented_default(): + + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.PERSON, + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain + ) + self.adapter.register_listener( FriendMessage, on_friend_message @@ -75,14 +88,26 @@ class QQBotManager: async def on_stranger_message(event: StrangerMessage): - await self.ap.query_pool.add_query( - launcher_type=core_entities.LauncherTypes.PERSON, - launcher_id=event.sender.id, - sender_id=event.sender.id, - message_event=event, - message_chain=event.message_chain + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.PersonMessageReceived( + launcher_type='person', + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_chain=event.message_chain, + query=None + ) ) + if not event_ctx.is_prevented_default(): + + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.PERSON, + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain + ) + # nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件 if config['msg_source_adapter'] == 'yirimirai': self.adapter.register_listener( @@ -92,14 +117,26 @@ class QQBotManager: async def on_group_message(event: GroupMessage): - await self.ap.query_pool.add_query( - launcher_type=core_entities.LauncherTypes.GROUP, - launcher_id=event.group.id, - sender_id=event.sender.id, - message_event=event, - message_chain=event.message_chain + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.GroupMessageReceived( + launcher_type='person', + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_chain=event.message_chain, + query=None + ) ) + if not event_ctx.is_prevented_default(): + + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.GROUP, + launcher_id=event.group.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain + ) + self.adapter.register_listener( GroupMessage, on_group_message diff --git a/pkg/plugin/events.py b/pkg/plugin/events.py index e5b9fb0..d414daf 100644 --- a/pkg/plugin/events.py +++ b/pkg/plugin/events.py @@ -5,14 +5,13 @@ import typing import pydantic import mirai -from . import context from ..core import entities as core_entities from ..provider import entities as llm_entities class BaseEventModel(pydantic.BaseModel): - query: core_entities.Query + query: core_entities.Query | None class Config: arbitrary_types_allowed = True @@ -142,7 +141,7 @@ class NormalMessageResponded(BaseEventModel): """会话对象""" prefix: str - """回复消息的前缀,可修改""" + """回复消息的前缀""" response_text: str """回复消息的文本""" @@ -157,24 +156,6 @@ class NormalMessageResponded(BaseEventModel): """回复消息组件列表""" -class SessionExplicitReset(BaseEventModel): - """会话被显式重置时触发""" - - session_name: str - - session: core_entities.Session - - -class SessionExpired(BaseEventModel): - """会话过期时触发""" - - session_name: str - - session: core_entities.Session - - session_expire_time: int - - class PromptPreProcessing(BaseEventModel): """会话中的Prompt预处理时触发""" @@ -185,6 +166,3 @@ class PromptPreProcessing(BaseEventModel): prompt: list[llm_entities.Message] """此对话现有消息记录,可修改""" - - text_message: str - """消息文本,可修改""" diff --git a/pkg/plugin/loaders/legacy.py b/pkg/plugin/loaders/legacy.py index 1ba0e54..9bbee7c 100644 --- a/pkg/plugin/loaders/legacy.py +++ b/pkg/plugin/loaders/legacy.py @@ -75,7 +75,7 @@ class PluginLoader(loader.PluginLoader): for k, v in ctx.event.dict().items(): args[k] = v - await func(plugin, **args) + func(plugin, **args) self._current_container.event_handlers[event] = handler diff --git a/pkg/plugin/manager.py b/pkg/plugin/manager.py index 488c6ca..8387937 100644 --- a/pkg/plugin/manager.py +++ b/pkg/plugin/manager.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing +import traceback from ..core import app from . import context, loader, events, installer, setting, models @@ -100,6 +101,9 @@ class PluginManager: for plugin in self.plugins: if plugin.enabled: if event.__class__ in plugin.event_handlers: + + is_prevented_default_before_call = ctx.is_prevented_default() + try: await plugin.event_handlers[event.__class__]( plugin.plugin_inst, @@ -107,12 +111,19 @@ class PluginManager: ) except Exception as e: self.ap.logger.error(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__} 时发生错误: {e}') - self.ap.logger.exception(e) + self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + if not is_prevented_default_before_call and ctx.is_prevented_default(): + self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行') + if ctx.is_prevented_postorder(): self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行') break + + for key in ctx.__return_value__.keys(): + if hasattr(ctx.event, key): + setattr(ctx.event, key, ctx.__return_value__[key][0]) self.ap.logger.debug(f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}') - return ctx \ No newline at end of file + return ctx diff --git a/pkg/plugin/models.py b/pkg/plugin/models.py index c10e09a..972eed1 100644 --- a/pkg/plugin/models.py +++ b/pkg/plugin/models.py @@ -3,7 +3,7 @@ from __future__ import annotations import typing from .context import BasePlugin as Plugin -from . import events +from .events import * def register( name: str, @@ -15,7 +15,7 @@ def register( def on( - event: typing.Type[events.BaseEventModel] + event: typing.Type[BaseEventModel] ) -> typing.Callable[[typing.Callable], typing.Callable]: pass