feat(plugin): 支持多个事件

This commit is contained in:
Rock Chin 2023-01-14 22:36:48 +08:00
parent 56664f9fbc
commit 6d81821557
7 changed files with 128 additions and 36 deletions

View File

@ -40,6 +40,9 @@ class DataGatherer:
except:
return
def get_usage(self, key_md5):
return self.usage[key_md5] if key_md5 in self.usage else {}
def report_text_model_usage(self, model, text):
key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()

View File

@ -2,10 +2,8 @@
import hashlib
import logging
import pkg.database.manager
import pkg.qqbot.manager
import pkg.utils.context
import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models
class KeysManager:
api_key = {}
@ -50,7 +48,16 @@ class KeysManager:
for key_name in self.api_key:
if self.api_key[key_name] not in self.exceeded:
self.using_key = self.api_key[key_name]
logging.info("使用api-key:" + key_name)
# 触发插件事件
args = {
"key_name": key_name,
"key_list": self.api_key.keys()
}
_ = plugin_host.emit(plugin_models.KeySwitched, **args)
return True, key_name
self.using_key = list(self.api_key.values())[0]

View File

@ -6,6 +6,9 @@ import pkg.openai.manager
import pkg.database.manager
import pkg.utils.context
import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models
# 运行时保存的所有session
sessions = {}
@ -120,6 +123,17 @@ class Session:
config = pkg.utils.context.get_config()
if int(time.time()) - self.last_interact_timestamp > config.session_expire_time:
logging.info('session {} 已过期'.format(self.name))
# 触发插件事件
args = {
'session_name': self.name,
'session': self,
'session_expire_time': config.session_expire_time
}
event = pkg.plugin.host.emit(plugin_models.SessionExpired, **args)
if event.is_prevented_default():
return
self.reset(expired=True, schedule_new=False)
# 删除此session
@ -131,6 +145,18 @@ class Session:
def append(self, text: str) -> str:
self.last_interact_timestamp = int(time.time())
# 触发插件事件
if self.prompt == self.get_default_prompt():
args = {
'session_name': self.name,
'session': self,
'default_prompt': self.prompt,
}
event = pkg.plugin.host.emit(plugin_models.SessionFirstMessage, **args)
if event.is_prevented_default():
return None
# max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7
config = pkg.utils.context.get_config()
max_rounds = 1000 # 不再限制回合数
@ -220,6 +246,15 @@ class Session:
if self.prompt != self.get_default_prompt():
self.persistence()
if explicit:
# 触发插件事件
args = {
'session_name': self.name,
'session': self
}
# 此事件不支持阻止默认行为
_ = pkg.plugin.host.emit(plugin_models.SessionExplicitReset, **args)
pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
if expired:

View File

@ -93,9 +93,11 @@ class EventContext:
self.__prevent_postorder__ = True
def is_prevented_default(self):
"""是否阻止默认行为"""
return self.__prevent_default__
def is_prevented_postorder(self):
"""是否阻止后序插件执行"""
return self.__prevent_postorder__
def __init__(self, name: str):
@ -107,6 +109,8 @@ class EventContext:
def emit(event_name: str, **kwargs) -> EventContext:
""" 触发事件 """
import pkg.utils.context as context
if context.get_plugin_host() is None:
return None
return context.get_plugin_host().emit(event_name, **kwargs)

View File

@ -39,7 +39,7 @@ PersonCommand = "person_command"
launcher_id: int 发起对象ID(群号/QQ号)
sender_id: int 发送者ID(QQ号)
command: str 指令
args: list[str] 参数列表
params: list[str] 参数列表
text_message: str 完整指令文本
is_admin: bool 是否为管理员
"""
@ -60,25 +60,22 @@ GroupCommand = "group_command"
launcher_id: int 发起对象ID(群号/QQ号)
sender_id: int 发送者ID(QQ号)
command: str 指令
args: list[str] 参数列表
params: list[str] 参数列表
text_message: str 完整指令文本
is_admin: bool 是否为管理员
"""
SessionFirstMessage = "session_first_message"
"""会话被第一次交互时触发
kwargs:
launcher_type: str 发起对象类型(group/person)
launcher_id: int 发起对象ID(群号/QQ号)
session_name: str 会话名称(<launcher_type>_<launcher_id>)
session: pkg.openai.session.Session 会话对象
default_prompt: str 预设值
"""
SessionReset = "session_reset"
"""会话被用户手动重置时触发
SessionExplicitReset = "session_reset"
"""会话被用户手动重置时触发,此事件不支持阻止默认行为
kwargs:
launcher_type: str 发起对象类型(group/person)
launcher_id: int 发起对象ID(群号/QQ号)
session_name: str 会话名称(<launcher_type>_<launcher_id>)
session: pkg.openai.session.Session 会话对象
"""
@ -86,11 +83,9 @@ SessionReset = "session_reset"
SessionExpired = "session_expired"
"""会话过期时触发
kwargs:
launcher_type: str 发起对象类型(group/person)
launcher_id: int 发起对象ID(群号/QQ号)
session_name: str 会话名称(<launcher_type>_<launcher_id>)
session: pkg.openai.session.Session 会话对象
expired_time: int 已设置的会话过期时间()
session_expire_time: int 已设置的会话过期时间()
"""
KeyExceeded = "key_exceeded"
@ -98,17 +93,25 @@ KeyExceeded = "key_exceeded"
kwargs:
key_name: str 超额的api-key名称
usage: dict 超额的api-key使用情况
exceeded_key: list[str] 超额的api-key列表
exceeded_keys: list[str] 超额的api-key列表
"""
KeySwitched = "key_switched"
"""api-key超额切换成功时触发
"""api-key超额切换成功时触发,此事件不支持阻止默认行为
kwargs:
key_name: str 切换成功的api-key名称
key_list: list[str] api-key列表
"""
def on(event: str):
"""注册事件监听器
:param
event: str 事件名称
"""
return Plugin.on(event)
class Plugin:
host: host.PluginHost

View File

@ -4,6 +4,8 @@ import openai
import pkg.utils.context
import pkg.openai.session
import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models
def process_normal_message(text_message: str, mgr, config, launcher_type: str, launcher_id: int) -> list:
session_name = f"{launcher_type}_{launcher_id}"
@ -26,18 +28,29 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str, l
pkg.utils.context.get_openai_manager().key_mgr.using_key
)
pkg.utils.context.get_openai_manager().key_mgr.set_current_exceeded()
switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch()
if not switched:
mgr.notify_admin(
"api-key调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key".format(
current_key_name))
reply = ["[bot]err:API调用额度超额请联系作者或等待修复"]
else:
openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key()
mgr.notify_admin("api-key调用额度超限({}),接口报错,已切换到{}".format(current_key_name, name))
reply = ["[bot]err:API调用额度超额已自动切换请重新发送消息"]
continue
# 触发插件事件
args = {
'key_name': current_key_name,
'usage': pkg.utils.context.get_openai_manager().audit_mgr
.get_usage(pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()),
'exceeded_keys': pkg.utils.context.get_openai_manager().key_mgr.exceeded,
}
event = plugin_host.emit(plugin_models.KeyExceeded, **args)
if not event.is_prevented_default():
switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch()
if not switched:
mgr.notify_admin(
"api-key调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key".format(
current_key_name))
reply = ["[bot]err:API调用额度超额请联系作者或等待修复"]
else:
openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key()
mgr.notify_admin("api-key调用额度超限({}),接口报错,已切换到{}".format(current_key_name, name))
reply = ["[bot]err:API调用额度超额已自动切换请重新发送消息"]
continue
except openai.error.InvalidRequestError as e:
mgr.notify_admin("{}API调用参数错误:{}\n\n这可能是由于config.py中的prompt_submit_length参数或"
"completion_api_params中的max_tokens参数数值过大导致的请尝试将其降低".format(

View File

@ -1,14 +1,10 @@
# 此模块提供了消息处理的具体逻辑的接口
import asyncio
import datetime
import json
import threading
from func_timeout import func_set_timeout
import logging
import openai
from mirai import Image, MessageChain, Plain
from mirai import MessageChain, Plain
# 这里不使用动态引入config
# 因为在这里动态引入会卡死程序
@ -23,6 +19,9 @@ import pkg.utils.context
import pkg.qqbot.message
import pkg.qqbot.command
import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models
processing = []
@ -68,11 +67,39 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
try:
if text_message.startswith('!') or text_message.startswith(""): # 指令
reply = pkg.qqbot.command.process_command(session_name, text_message,
mgr, config, launcher_type, launcher_id)
# 触发插件事件
args = {
'launcher_type': launcher_type,
'launcher_id': launcher_id,
'sender_id': sender_id,
'command': text_message[1:].strip().split(' ')[0],
'params': text_message[1:].strip().split(' ')[1:],
'text_message': text_message,
'is_admin': sender_id is config.admin_qq,
}
event = plugin_host.emit(plugin_models.PersonCommand
if launcher_type == 'person'
else plugin_models.GroupCommand, **args)
if not event.is_prevented_default():
reply = pkg.qqbot.command.process_command(session_name, text_message,
mgr, config, launcher_type, launcher_id)
else: # 消息
reply = pkg.qqbot.message.process_normal_message(text_message, mgr, config, launcher_type, launcher_id)
# 触发插件事件
args = {
"launcher_type": launcher_type,
"launcher_id": launcher_id,
"sender_id": sender_id,
"text_message": text_message,
}
event = plugin_host.emit(plugin_models.PersonNormalMessage
if launcher_type == 'person'
else plugin_models.GroupNormalMessage, **args)
if not event.is_prevented_default():
reply = pkg.qqbot.message.process_normal_message(text_message,
mgr, config, launcher_type, launcher_id)
if reply is not None and type(reply[0]) == str:
logging.info(