mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
refactor: 恢复插件事件调用
This commit is contained in:
parent
e2de3d0102
commit
33d600fb6b
|
@ -19,7 +19,7 @@ from ..utils import version as version_mgr, proxy as proxy_mgr
|
|||
|
||||
|
||||
class Application:
|
||||
im_mgr: im_mgr.QQBotManager = None
|
||||
im_mgr: im_mgr.PlatformManager = None
|
||||
|
||||
cmd_mgr: cmdmgr.CommandManager = None
|
||||
|
||||
|
|
|
@ -88,10 +88,13 @@ async def make_app() -> app.Application:
|
|||
|
||||
# 发送公告
|
||||
ann_mgr = announce.AnnouncementManager(ap)
|
||||
announcements = await ann_mgr.fetch_new()
|
||||
try:
|
||||
announcements = await ann_mgr.fetch_new()
|
||||
|
||||
for ann in announcements:
|
||||
ap.logger.info(f'[公告] {ann.time}: {ann.content}')
|
||||
for ann in announcements:
|
||||
ap.logger.info(f'[公告] {ann.time}: {ann.content}')
|
||||
except Exception as e:
|
||||
ap.logger.warning(f'获取公告时出错: {e}')
|
||||
|
||||
ap.query_pool = pool.QueryPool()
|
||||
|
||||
|
@ -99,8 +102,13 @@ async def make_app() -> app.Application:
|
|||
await ver_mgr.initialize()
|
||||
ap.ver_mgr = ver_mgr
|
||||
|
||||
if await ap.ver_mgr.is_new_version_available():
|
||||
ap.logger.info("有新版本可用,请使用 !update 命令更新")
|
||||
try:
|
||||
|
||||
if await ap.ver_mgr.is_new_version_available():
|
||||
ap.logger.info("有新版本可用,请使用 !update 命令更新")
|
||||
|
||||
except Exception as e:
|
||||
ap.logger.warning(f"检查版本更新时出错: {e}")
|
||||
|
||||
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
|
||||
await plugin_mgr_inst.initialize()
|
||||
|
@ -141,7 +149,7 @@ async def make_app() -> app.Application:
|
|||
await llm_tool_mgr_inst.initialize()
|
||||
ap.tool_mgr = llm_tool_mgr_inst
|
||||
|
||||
im_mgr_inst = im_mgr.QQBotManager(ap=ap)
|
||||
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
ap.im_mgr = im_mgr_inst
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import traceback
|
|||
|
||||
from . import app, entities
|
||||
from ..pipeline import entities as pipeline_entities
|
||||
from ..plugin import events
|
||||
|
||||
DEFAULT_QUERY_CONCURRENCY = 10
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from .. import handler
|
|||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....provider import entities as llm_entities
|
||||
from ....plugin import events
|
||||
|
||||
|
||||
class ChatMessageHandler(handler.MessageHandler):
|
||||
|
@ -22,23 +23,152 @@ class ChatMessageHandler(handler.MessageHandler):
|
|||
# 取conversation
|
||||
# 调API
|
||||
# 生成器
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
conversation = await self.ap.sess_mgr.get_conversation(session)
|
||||
# 触发插件事件
|
||||
event_class = events.PersonNormalMessageReceived if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupNormalMessageReceived
|
||||
|
||||
conversation.messages.append(
|
||||
llm_entities.Message(
|
||||
role="user",
|
||||
content=str(query.message_chain)
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_class(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
text_message=str(query.message_chain),
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
async for result in conversation.use_model.requester.request(query, conversation):
|
||||
conversation.messages.append(result)
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(str(result))])
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
if event_ctx.event.alter is not None:
|
||||
query.message_chain = mirai.MessageChain([
|
||||
mirai.Plain(event_ctx.event.alter)
|
||||
])
|
||||
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
conversation = await self.ap.sess_mgr.get_conversation(session)
|
||||
|
||||
# =========== 触发事件 PromptPreProcessing
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.PromptPreProcessing(
|
||||
session_name=f'{session.launcher_type.value}_{session.launcher_id}',
|
||||
default_prompt=conversation.prompt.messages,
|
||||
prompt=conversation.messages,
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
conversation.prompt.messages = event_ctx.event.default_prompt
|
||||
conversation.messages = event_ctx.event.prompt
|
||||
|
||||
conversation.messages.append(
|
||||
llm_entities.Message(
|
||||
role="user",
|
||||
content=str(query.message_chain)
|
||||
)
|
||||
)
|
||||
|
||||
called_functions = []
|
||||
|
||||
async for result in conversation.use_model.requester.request(query, conversation):
|
||||
conversation.messages.append(result)
|
||||
|
||||
# 转换成可读消息
|
||||
if result.role == 'assistant':
|
||||
|
||||
reply_text = ''
|
||||
|
||||
if result.content is not None: # 有内容
|
||||
reply_text = result.content
|
||||
|
||||
# ============= 触发插件事件 ===============
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
session=session,
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=called_functions,
|
||||
query=query
|
||||
)
|
||||
)
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
if result.tool_calls is not None: # 有函数调用
|
||||
|
||||
function_names = [tc.function.name for tc in result.tool_calls]
|
||||
|
||||
reply_text = f'调用函数 {".".join(function_names)}...'
|
||||
|
||||
called_functions.extend(function_names)
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
if self.ap.cfg_mgr.data['trace_function_calls']:
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
session=session,
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=called_functions,
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
|
|
@ -6,6 +6,7 @@ import mirai
|
|||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....plugin import events
|
||||
|
||||
|
||||
class CommandHandler(handler.MessageHandler):
|
||||
|
@ -16,28 +17,34 @@ class CommandHandler(handler.MessageHandler):
|
|||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理
|
||||
"""
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
command_text = str(query.message_chain).strip()[1:]
|
||||
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent
|
||||
|
||||
async for ret in self.ap.cmd_mgr.execute(
|
||||
command_text=command_text,
|
||||
query=query,
|
||||
session=session
|
||||
):
|
||||
if ret.error is not None:
|
||||
query.resp_message_chain = mirai.MessageChain([
|
||||
mirai.Plain(str(ret.error))
|
||||
])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
elif ret.text is not None:
|
||||
query.resp_message_chain = mirai.MessageChain([
|
||||
mirai.Plain(ret.text)
|
||||
])
|
||||
privilege = 1
|
||||
if query.sender_id == self.ap.cfg_mgr.data['admin_qq'] \
|
||||
or query.sender_id in self.ap.cfg_mgr['admin_qq']:
|
||||
privilege = 2
|
||||
|
||||
spt = str(query.message_chain).strip().split(' ')
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_class(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
command=spt[0],
|
||||
params=spt[1:] if len(spt) > 1 else [],
|
||||
text_message=str(query.message_chain),
|
||||
is_admin=(privilege==2),
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
|
@ -48,3 +55,43 @@ class CommandHandler(handler.MessageHandler):
|
|||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
if event_ctx.event.alter is not None:
|
||||
query.message_chain = mirai.MessageChain([
|
||||
mirai.Plain(event_ctx.event.alter)
|
||||
])
|
||||
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
command_text = str(query.message_chain).strip()[1:]
|
||||
|
||||
async for ret in self.ap.cmd_mgr.execute(
|
||||
command_text=command_text,
|
||||
query=query,
|
||||
session=session
|
||||
):
|
||||
if ret.error is not None:
|
||||
query.resp_message_chain = mirai.MessageChain([
|
||||
mirai.Plain(str(ret.error))
|
||||
])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
elif ret.text is not None:
|
||||
query.resp_message_chain = mirai.MessageChain([
|
||||
mirai.Plain(ret.text)
|
||||
])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
|
|
|
@ -18,10 +18,11 @@ from ..platform import adapter as msadapter
|
|||
from .ratelim import ratelim
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from ..plugin import events
|
||||
|
||||
|
||||
# 控制QQ消息输入输出的类
|
||||
class QQBotManager:
|
||||
class PlatformManager:
|
||||
|
||||
adapter: msadapter.MessageSourceAdapter = None
|
||||
|
||||
|
@ -60,14 +61,26 @@ class QQBotManager:
|
|||
|
||||
async def on_friend_message(event: FriendMessage):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.PersonMessageReceived(
|
||||
launcher_type='person',
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_chain=event.message_chain,
|
||||
query=None
|
||||
)
|
||||
)
|
||||
|
||||
if not event_ctx.is_prevented_default():
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
)
|
||||
|
||||
self.adapter.register_listener(
|
||||
FriendMessage,
|
||||
on_friend_message
|
||||
|
@ -75,14 +88,26 @@ class QQBotManager:
|
|||
|
||||
async def on_stranger_message(event: StrangerMessage):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.PersonMessageReceived(
|
||||
launcher_type='person',
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_chain=event.message_chain,
|
||||
query=None
|
||||
)
|
||||
)
|
||||
|
||||
if not event_ctx.is_prevented_default():
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
)
|
||||
|
||||
# nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件
|
||||
if config['msg_source_adapter'] == 'yirimirai':
|
||||
self.adapter.register_listener(
|
||||
|
@ -92,14 +117,26 @@ class QQBotManager:
|
|||
|
||||
async def on_group_message(event: GroupMessage):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.GroupMessageReceived(
|
||||
launcher_type='person',
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_chain=event.message_chain,
|
||||
query=None
|
||||
)
|
||||
)
|
||||
|
||||
if not event_ctx.is_prevented_default():
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
)
|
||||
|
||||
self.adapter.register_listener(
|
||||
GroupMessage,
|
||||
on_group_message
|
||||
|
|
|
@ -5,14 +5,13 @@ import typing
|
|||
import pydantic
|
||||
import mirai
|
||||
|
||||
from . import context
|
||||
from ..core import entities as core_entities
|
||||
from ..provider import entities as llm_entities
|
||||
|
||||
|
||||
class BaseEventModel(pydantic.BaseModel):
|
||||
|
||||
query: core_entities.Query
|
||||
query: core_entities.Query | None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
@ -142,7 +141,7 @@ class NormalMessageResponded(BaseEventModel):
|
|||
"""会话对象"""
|
||||
|
||||
prefix: str
|
||||
"""回复消息的前缀,可修改"""
|
||||
"""回复消息的前缀"""
|
||||
|
||||
response_text: str
|
||||
"""回复消息的文本"""
|
||||
|
@ -157,24 +156,6 @@ class NormalMessageResponded(BaseEventModel):
|
|||
"""回复消息组件列表"""
|
||||
|
||||
|
||||
class SessionExplicitReset(BaseEventModel):
|
||||
"""会话被显式重置时触发"""
|
||||
|
||||
session_name: str
|
||||
|
||||
session: core_entities.Session
|
||||
|
||||
|
||||
class SessionExpired(BaseEventModel):
|
||||
"""会话过期时触发"""
|
||||
|
||||
session_name: str
|
||||
|
||||
session: core_entities.Session
|
||||
|
||||
session_expire_time: int
|
||||
|
||||
|
||||
class PromptPreProcessing(BaseEventModel):
|
||||
"""会话中的Prompt预处理时触发"""
|
||||
|
||||
|
@ -185,6 +166,3 @@ class PromptPreProcessing(BaseEventModel):
|
|||
|
||||
prompt: list[llm_entities.Message]
|
||||
"""此对话现有消息记录,可修改"""
|
||||
|
||||
text_message: str
|
||||
"""消息文本,可修改"""
|
||||
|
|
|
@ -75,7 +75,7 @@ class PluginLoader(loader.PluginLoader):
|
|||
for k, v in ctx.event.dict().items():
|
||||
args[k] = v
|
||||
|
||||
await func(plugin, **args)
|
||||
func(plugin, **args)
|
||||
|
||||
self._current_container.event_handlers[event] = handler
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from ..core import app
|
||||
from . import context, loader, events, installer, setting, models
|
||||
|
@ -100,6 +101,9 @@ class PluginManager:
|
|||
for plugin in self.plugins:
|
||||
if plugin.enabled:
|
||||
if event.__class__ in plugin.event_handlers:
|
||||
|
||||
is_prevented_default_before_call = ctx.is_prevented_default()
|
||||
|
||||
try:
|
||||
await plugin.event_handlers[event.__class__](
|
||||
plugin.plugin_inst,
|
||||
|
@ -107,12 +111,19 @@ class PluginManager:
|
|||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__} 时发生错误: {e}')
|
||||
self.ap.logger.exception(e)
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
if not is_prevented_default_before_call and ctx.is_prevented_default():
|
||||
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行')
|
||||
|
||||
if ctx.is_prevented_postorder():
|
||||
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行')
|
||||
break
|
||||
|
||||
for key in ctx.__return_value__.keys():
|
||||
if hasattr(ctx.event, key):
|
||||
setattr(ctx.event, key, ctx.__return_value__[key][0])
|
||||
|
||||
self.ap.logger.debug(f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}')
|
||||
|
||||
return ctx
|
||||
return ctx
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
import typing
|
||||
|
||||
from .context import BasePlugin as Plugin
|
||||
from . import events
|
||||
from .events import *
|
||||
|
||||
def register(
|
||||
name: str,
|
||||
|
@ -15,7 +15,7 @@ def register(
|
|||
|
||||
|
||||
def on(
|
||||
event: typing.Type[events.BaseEventModel]
|
||||
event: typing.Type[BaseEventModel]
|
||||
) -> typing.Callable[[typing.Callable], typing.Callable]:
|
||||
pass
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user