diff --git a/pkg/audit/gatherer.py b/pkg/audit/gatherer.py index 5609384..6e0ea3e 100644 --- a/pkg/audit/gatherer.py +++ b/pkg/audit/gatherer.py @@ -40,6 +40,9 @@ class DataGatherer: except: return + def get_usage(self, key_md5): + return self.usage[key_md5] if key_md5 in self.usage else {} + def report_text_model_usage(self, model, text): key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() diff --git a/pkg/openai/keymgr.py b/pkg/openai/keymgr.py index 5add344..78162fa 100644 --- a/pkg/openai/keymgr.py +++ b/pkg/openai/keymgr.py @@ -2,10 +2,8 @@ import hashlib import logging -import pkg.database.manager -import pkg.qqbot.manager -import pkg.utils.context - +import pkg.plugin.host as plugin_host +import pkg.plugin.models as plugin_models class KeysManager: api_key = {} @@ -50,7 +48,16 @@ class KeysManager: for key_name in self.api_key: if self.api_key[key_name] not in self.exceeded: self.using_key = self.api_key[key_name] + logging.info("使用api-key:" + key_name) + + # 触发插件事件 + args = { + "key_name": key_name, + "key_list": self.api_key.keys() + } + _ = plugin_host.emit(plugin_models.KeySwitched, **args) + return True, key_name self.using_key = list(self.api_key.values())[0] diff --git a/pkg/openai/session.py b/pkg/openai/session.py index b773197..fd7a336 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -6,6 +6,9 @@ import pkg.openai.manager import pkg.database.manager import pkg.utils.context +import pkg.plugin.host as plugin_host +import pkg.plugin.models as plugin_models + # 运行时保存的所有session sessions = {} @@ -120,6 +123,17 @@ class Session: config = pkg.utils.context.get_config() if int(time.time()) - self.last_interact_timestamp > config.session_expire_time: logging.info('session {} 已过期'.format(self.name)) + + # 触发插件事件 + args = { + 'session_name': self.name, + 'session': self, + 'session_expire_time': config.session_expire_time + } + event = pkg.plugin.host.emit(plugin_models.SessionExpired, **args) + if event.is_prevented_default(): + return + self.reset(expired=True, schedule_new=False) # 删除此session @@ -131,6 +145,18 @@ class Session: def append(self, text: str) -> str: self.last_interact_timestamp = int(time.time()) + # 触发插件事件 + if self.prompt == self.get_default_prompt(): + args = { + 'session_name': self.name, + 'session': self, + 'default_prompt': self.prompt, + } + + event = pkg.plugin.host.emit(plugin_models.SessionFirstMessage, **args) + if event.is_prevented_default(): + return None + # max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7 config = pkg.utils.context.get_config() max_rounds = 1000 # 不再限制回合数 @@ -220,6 +246,15 @@ class Session: if self.prompt != self.get_default_prompt(): self.persistence() if explicit: + # 触发插件事件 + args = { + 'session_name': self.name, + 'session': self + } + + # 此事件不支持阻止默认行为 + _ = pkg.plugin.host.emit(plugin_models.SessionExplicitReset, **args) + pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp) if expired: diff --git a/pkg/plugin/host.py b/pkg/plugin/host.py index 1e7fd48..87ece13 100644 --- a/pkg/plugin/host.py +++ b/pkg/plugin/host.py @@ -93,9 +93,11 @@ class EventContext: self.__prevent_postorder__ = True def is_prevented_default(self): + """是否阻止默认行为""" return self.__prevent_default__ def is_prevented_postorder(self): + """是否阻止后序插件执行""" return self.__prevent_postorder__ def __init__(self, name: str): @@ -107,6 +109,8 @@ class EventContext: def emit(event_name: str, **kwargs) -> EventContext: """ 触发事件 """ import pkg.utils.context as context + if context.get_plugin_host() is None: + return None return context.get_plugin_host().emit(event_name, **kwargs) diff --git a/pkg/plugin/models.py b/pkg/plugin/models.py index 1c0b66a..1cdfe2a 100644 --- a/pkg/plugin/models.py +++ b/pkg/plugin/models.py @@ -39,7 +39,7 @@ PersonCommand = "person_command" launcher_id: int 发起对象ID(群号/QQ号) sender_id: int 发送者ID(QQ号) command: str 指令 - args: list[str] 参数列表 + params: list[str] 参数列表 text_message: str 完整指令文本 is_admin: bool 是否为管理员 """ @@ -60,25 +60,22 @@ GroupCommand = "group_command" launcher_id: int 发起对象ID(群号/QQ号) sender_id: int 发送者ID(QQ号) command: str 指令 - args: list[str] 参数列表 + params: list[str] 参数列表 text_message: str 完整指令文本 + is_admin: bool 是否为管理员 """ SessionFirstMessage = "session_first_message" """会话被第一次交互时触发 kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) session_name: str 会话名称(_) session: pkg.openai.session.Session 会话对象 default_prompt: str 预设值 """ -SessionReset = "session_reset" -"""会话被用户手动重置时触发 +SessionExplicitReset = "session_reset" +"""会话被用户手动重置时触发,此事件不支持阻止默认行为 kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) session_name: str 会话名称(_) session: pkg.openai.session.Session 会话对象 """ @@ -86,11 +83,9 @@ SessionReset = "session_reset" SessionExpired = "session_expired" """会话过期时触发 kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) session_name: str 会话名称(_) session: pkg.openai.session.Session 会话对象 - expired_time: int 已设置的会话过期时间(秒) + session_expire_time: int 已设置的会话过期时间(秒) """ KeyExceeded = "key_exceeded" @@ -98,17 +93,25 @@ KeyExceeded = "key_exceeded" kwargs: key_name: str 超额的api-key名称 usage: dict 超额的api-key使用情况 - exceeded_key: list[str] 超额的api-key列表 + exceeded_keys: list[str] 超额的api-key列表 """ KeySwitched = "key_switched" -"""api-key超额切换成功时触发 +"""api-key超额切换成功时触发,此事件不支持阻止默认行为 kwargs: key_name: str 切换成功的api-key名称 key_list: list[str] api-key列表 """ +def on(event: str): + """注册事件监听器 + :param + event: str 事件名称 + """ + return Plugin.on(event) + + class Plugin: host: host.PluginHost diff --git a/pkg/qqbot/message.py b/pkg/qqbot/message.py index e2db998..ca05eaf 100644 --- a/pkg/qqbot/message.py +++ b/pkg/qqbot/message.py @@ -4,6 +4,8 @@ import openai import pkg.utils.context import pkg.openai.session +import pkg.plugin.host as plugin_host +import pkg.plugin.models as plugin_models def process_normal_message(text_message: str, mgr, config, launcher_type: str, launcher_id: int) -> list: session_name = f"{launcher_type}_{launcher_id}" @@ -26,18 +28,29 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str, l pkg.utils.context.get_openai_manager().key_mgr.using_key ) pkg.utils.context.get_openai_manager().key_mgr.set_current_exceeded() - switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch() - if not switched: - mgr.notify_admin( - "api-key调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key".format( - current_key_name)) - reply = ["[bot]err:API调用额度超额,请联系作者,或等待修复"] - else: - openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key() - mgr.notify_admin("api-key调用额度超限({}),接口报错,已切换到{}".format(current_key_name, name)) - reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"] - continue + # 触发插件事件 + args = { + 'key_name': current_key_name, + 'usage': pkg.utils.context.get_openai_manager().audit_mgr + .get_usage(pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()), + 'exceeded_keys': pkg.utils.context.get_openai_manager().key_mgr.exceeded, + } + event = plugin_host.emit(plugin_models.KeyExceeded, **args) + + if not event.is_prevented_default(): + switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch() + + if not switched: + mgr.notify_admin( + "api-key调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key".format( + current_key_name)) + reply = ["[bot]err:API调用额度超额,请联系作者,或等待修复"] + else: + openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key() + mgr.notify_admin("api-key调用额度超限({}),接口报错,已切换到{}".format(current_key_name, name)) + reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"] + continue except openai.error.InvalidRequestError as e: mgr.notify_admin("{}API调用参数错误:{}\n\n这可能是由于config.py中的prompt_submit_length参数或" "completion_api_params中的max_tokens参数数值过大导致的,请尝试将其降低".format( diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index c17c7ff..4d2fe26 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -1,14 +1,10 @@ # 此模块提供了消息处理的具体逻辑的接口 import asyncio -import datetime -import json -import threading from func_timeout import func_set_timeout import logging -import openai -from mirai import Image, MessageChain, Plain +from mirai import MessageChain, Plain # 这里不使用动态引入config # 因为在这里动态引入会卡死程序 @@ -23,6 +19,9 @@ import pkg.utils.context import pkg.qqbot.message import pkg.qqbot.command +import pkg.plugin.host as plugin_host +import pkg.plugin.models as plugin_models + processing = [] @@ -68,11 +67,39 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes try: if text_message.startswith('!') or text_message.startswith("!"): # 指令 - reply = pkg.qqbot.command.process_command(session_name, text_message, - mgr, config, launcher_type, launcher_id) + # 触发插件事件 + args = { + 'launcher_type': launcher_type, + 'launcher_id': launcher_id, + 'sender_id': sender_id, + 'command': text_message[1:].strip().split(' ')[0], + 'params': text_message[1:].strip().split(' ')[1:], + 'text_message': text_message, + 'is_admin': sender_id is config.admin_qq, + } + event = plugin_host.emit(plugin_models.PersonCommand + if launcher_type == 'person' + else plugin_models.GroupCommand, **args) + + if not event.is_prevented_default(): + reply = pkg.qqbot.command.process_command(session_name, text_message, + mgr, config, launcher_type, launcher_id) else: # 消息 - reply = pkg.qqbot.message.process_normal_message(text_message, mgr, config, launcher_type, launcher_id) + # 触发插件事件 + args = { + "launcher_type": launcher_type, + "launcher_id": launcher_id, + "sender_id": sender_id, + "text_message": text_message, + } + event = plugin_host.emit(plugin_models.PersonNormalMessage + if launcher_type == 'person' + else plugin_models.GroupNormalMessage, **args) + + if not event.is_prevented_default(): + reply = pkg.qqbot.message.process_normal_message(text_message, + mgr, config, launcher_type, launcher_id) if reply is not None and type(reply[0]) == str: logging.info(