Merge commit '830ee704da0903a8922dc757381cdf6fd68870a3'

chordfish 2023-03-09 18:34:03 +08:00
commit f448276423
17 changed files with 0 additions and 2205 deletions


@@ -1,2 +0,0 @@
"""OpenAI 接口处理及会话管理相关
"""


@@ -1,79 +0,0 @@
# Management of multiple scenario presets
__current__ = "default"
"""Name of the scenario preset currently used by default.
Switched by the administrator with the `!default <名称>` command.
"""

__prompts_from_files__ = {}
"""Scenario presets read from files."""


def read_prompt_from_file() -> str:
    """Read presets from files."""
    # Read every file under prompts/, using the file name as key and the file
    # content as value, and store the result in __prompts_from_files__.
    global __prompts_from_files__
    import os

    __prompts_from_files__ = {}
    for file in os.listdir("prompts"):
        with open(os.path.join("prompts", file), encoding="utf-8") as f:
            __prompts_from_files__[file] = f.read()


def get_prompt_dict() -> dict:
    """Get the preset dictionary."""
    import config

    default_prompt = config.default_prompt
    if type(default_prompt) == str:
        default_prompt = {"default": default_prompt}
    elif type(default_prompt) == dict:
        pass
    else:
        raise TypeError("default_prompt must be str or dict")

    # Merge the presets read from files into default_prompt
    for key in __prompts_from_files__:
        default_prompt[key] = __prompts_from_files__[key]

    return default_prompt


def set_current(name):
    global __current__
    for key in get_prompt_dict():
        if key.lower().startswith(name.lower()):
            __current__ = key
            return
    raise KeyError("未找到情景预设: " + name)


def get_current():
    global __current__
    return __current__


def set_to_default():
    global __current__
    default_dict = get_prompt_dict()
    if "default" in default_dict:
        __current__ = "default"
    else:
        __current__ = list(default_dict.keys())[0]


def get_prompt(name: str = None) -> str:
    """Get a preset value."""
    if name is None:
        name = get_current()
    default_dict = get_prompt_dict()
    for key in default_dict:
        if key.lower().startswith(name.lower()):
            return default_dict[key]
    raise KeyError("未找到情景预设: " + name)
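
A minimal standalone sketch of the prefix-matching lookup that set_current and get_prompt above rely on; the preset table and its values here are illustrative, not part of the project:

presets = {
    "default": "你是一个助手",
    "catgirl": "你是一只猫娘",
}

def find_preset(name: str) -> str:
    # Same rule as get_prompt: the first key whose lowercase form starts with the query wins.
    for key, value in presets.items():
        if key.lower().startswith(name.lower()):
            return value
    raise KeyError("未找到情景预设: " + name)

print(find_preset("cat"))  # matches "catgirl"
print(find_preset("DEF"))  # matches "default"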


@@ -1,91 +0,0 @@
# 此模块提供了维护api-key的各种功能
import hashlib
import logging
import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models
class KeysManager:
api_key = {}
"""所有api-key"""
using_key = ""
"""当前使用的api-key
"""
alerted = []
"""已提示过超额的key
记录在此以避免重复提示
"""
exceeded = []
"""已超额的key
供自动切换功能识别
"""
def get_using_key(self):
return self.using_key
def get_using_key_md5(self):
return hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
def __init__(self, api_key):
if type(api_key) is dict:
self.api_key = api_key
elif type(api_key) is str:
self.api_key = {
"default": api_key
}
elif type(api_key) is list:
for i in range(len(api_key)):
self.api_key[str(i)] = api_key[i]
# 从usage中删除未加载的api-key的记录
# 不删了也许会运行时添加曾经有记录的api-key
self.auto_switch()
def auto_switch(self) -> (bool, str):
"""尝试切换api-key
Returns:
是否切换成功, 切换后的api-key的别名
"""
for key_name in self.api_key:
if self.api_key[key_name] not in self.exceeded:
self.using_key = self.api_key[key_name]
logging.info("使用api-key:" + key_name)
# 触发插件事件
args = {
"key_name": key_name,
"key_list": self.api_key.keys()
}
_ = plugin_host.emit(plugin_models.KeySwitched, **args)
return True, key_name
self.using_key = list(self.api_key.values())[0]
logging.info("使用api-key:" + list(self.api_key.keys())[0])
return False, ""
def add(self, key_name, key):
self.api_key[key_name] = key
def set_current_exceeded(self):
"""设置当前使用的api-key使用量超限
"""
self.exceeded.append(self.using_key)
def get_key_name(self, api_key):
"""根据api-key获取其别名"""
for key_name in self.api_key:
if self.api_key[key_name] == api_key:
return key_name
return ""


@@ -1,93 +0,0 @@
import logging
import openai
import pkg.openai.keymgr
import pkg.utils.context
import pkg.audit.gatherer
from pkg.openai.modelmgr import ModelRequest, create_openai_model_request
class OpenAIInteract:
"""OpenAI 接口封装
将文字接口和图片接口封装供调用方使用
"""
key_mgr: pkg.openai.keymgr.KeysManager = None
audit_mgr: pkg.audit.gatherer.DataGatherer = None
default_image_api_params = {
"size": "256x256",
}
def __init__(self, api_key: str):
self.key_mgr = pkg.openai.keymgr.KeysManager(api_key)
self.audit_mgr = pkg.audit.gatherer.DataGatherer()
logging.info("文字总使用量:%d", self.audit_mgr.get_total_text_length())
openai.api_key = self.key_mgr.get_using_key()
pkg.utils.context.set_openai_manager(self)
# 请求OpenAI Completion
def request_completion(self, prompts) -> str:
"""请求补全接口回复
Parameters:
prompts (str): 提示语
Returns:
str: 回复
"""
config = pkg.utils.context.get_config()
# 根据模型选择使用的接口
ai: ModelRequest = create_openai_model_request(
config.completion_api_params['model'],
'user',
config.openai_config["http_proxy"] if "http_proxy" in config.openai_config else None
)
ai.request(
prompts,
**config.completion_api_params
)
response = ai.get_response()
logging.debug("OpenAI response: %s", response)
if 'model' in config.completion_api_params:
self.audit_mgr.report_text_model_usage(config.completion_api_params['model'],
ai.get_total_tokens())
elif 'engine' in config.completion_api_params:
self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'],
response['usage']['total_tokens'])
return ai.get_message()
def request_image(self, prompt) -> dict:
"""请求图片接口回复
Parameters:
prompt (str): 提示语
Returns:
dict: 响应
"""
config = pkg.utils.context.get_config()
params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params
response = openai.Image.create(
prompt=prompt,
n=1,
**params
)
self.audit_mgr.report_image_model_usage(params['size'])
return response
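
request_image simply falls back to default_image_api_params when config defines no image_api_params; a small sketch of that fallback, using a stand-in config object:

from types import SimpleNamespace

default_image_api_params = {"size": "256x256"}

def resolve_image_params(config) -> dict:
    # Same hasattr-based fallback as in request_image above.
    return config.image_api_params if hasattr(config, "image_api_params") else default_image_api_params

print(resolve_image_params(SimpleNamespace()))                                       # {'size': '256x256'}
print(resolve_image_params(SimpleNamespace(image_api_params={"size": "512x512"})))   # {'size': '512x512'}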


@@ -1,184 +0,0 @@
"""OpenAI 接口底层封装
目前使用的对话接口有
ChatCompletion - gpt-3.5-turbo 等模型
Completion - text-davinci-003 等模型
此模块封装此两个接口的请求实现为上层提供统一的调用方式
"""
import openai, logging, threading, asyncio
import openai.error as aiE
COMPLETION_MODELS = {
'text-davinci-003',
'text-davinci-002',
'code-davinci-002',
'code-cushman-001',
'text-curie-001',
'text-babbage-001',
'text-ada-001',
}
CHAT_COMPLETION_MODELS = {
'gpt-3.5-turbo',
'gpt-3.5-turbo-0301',
}
EDIT_MODELS = {
}
IMAGE_MODELS = {
}
class ModelRequest:
"""模型接口请求父类"""
can_chat = False
runtime: threading.Thread = None
ret = {}
proxy: str = None
request_ready = True
error_info: str = "若在没有任何错误的情况下看到这句话请带着配置文件上报Issues"
def __init__(self, model_name, user_name, request_fun, http_proxy:str = None, time_out = None):
self.model_name = model_name
self.user_name = user_name
self.request_fun = request_fun
self.time_out = time_out
if http_proxy != None:
self.proxy = http_proxy
openai.proxy = self.proxy
self.request_ready = False
async def __a_request__(self, **kwargs):
"""异步请求"""
try:
self.ret:dict = await self.request_fun(**kwargs)
self.request_ready = True
except aiE.APIConnectionError as e:
self.error_info = "{}\n请检查网络连接或代理是否正常".format(e)
raise ConnectionError(self.error_info)
except ValueError as e:
    self.error_info = "{}\n该错误可能是由于http_proxy格式设置错误引起的".format(e)
except Exception as e:
self.error_info = "{}\n由于请求异常产生的未知错误,请查看日志".format(e)
raise Exception(self.error_info)
def request(self, **kwargs):
"""向接口发起请求"""
if self.proxy != None: #异步请求
self.request_ready = False
loop = asyncio.new_event_loop()
self.runtime = threading.Thread(
target=loop.run_until_complete,
args=(self.__a_request__(**kwargs),)
)
self.runtime.start()
else: #同步请求
self.ret = self.request_fun(**kwargs)
def __msg_handle__(self, msg):
"""将prompt dict转换成接口需要的格式"""
return msg
def ret_handle(self):
'''
API消息返回处理函数
若重写该方法应检查异步线程状态或在需要检查处super该方法
'''
if self.runtime != None and isinstance(self.runtime, threading.Thread):
self.runtime.join(self.time_out)
if self.request_ready:
return
raise Exception(self.error_info)
def get_total_tokens(self):
try:
return self.ret['usage']['total_tokens']
except Exception:
return 0
def get_message(self):
return self.message
def get_response(self):
return self.ret
class ChatCompletionModel(ModelRequest):
"""ChatCompletion接口的请求实现"""
Chat_role = ['system', 'user', 'assistant']
def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
if http_proxy == None:
request_fun = openai.ChatCompletion.create
else:
request_fun = openai.ChatCompletion.acreate
self.can_chat = True
super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs)
def request(self, prompts, **kwargs):
prompts = self.__msg_handle__(prompts)
kwargs['messages'] = prompts
super().request(**kwargs)
self.ret_handle()
def __msg_handle__(self, msgs):
temp_msgs = []
# 把msgs拷贝进temp_msgs
for msg in msgs:
temp_msgs.append(msg.copy())
return temp_msgs
def get_message(self):
return self.ret["choices"][0]["message"]['content'] #需要时直接加载加快请求速度,降低内存消耗
class CompletionModel(ModelRequest):
"""Completion接口的请求实现"""
def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
if http_proxy == None:
request_fun = openai.Completion.create
else:
request_fun = openai.Completion.acreate
super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs)
def request(self, prompts, **kwargs):
prompts = self.__msg_handle__(prompts)
kwargs['prompt'] = prompts
super().request(**kwargs)
self.ret_handle()
def __msg_handle__(self, msgs):
prompt = ''
for msg in msgs:
prompt = prompt + "{}: {}\n".format(msg['role'], msg['content'])
# for msg in msgs:
# if msg['role'] == 'assistant':
# prompt = prompt + "{}\n".format(msg['content'])
# else:
# prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content'])
prompt = prompt + "assistant: "
return prompt
def get_message(self):
return self.ret["choices"][0]["text"]
def create_openai_model_request(model_name: str, user_name: str = 'user', http_proxy:str = None) -> ModelRequest:
"""使用给定的模型名称创建模型请求对象"""
if model_name in CHAT_COMPLETION_MODELS:
model = ChatCompletionModel(model_name, user_name, http_proxy)
elif model_name in COMPLETION_MODELS:
model = CompletionModel(model_name, user_name, http_proxy)
else :
log = "找不到模型[{}],请检查配置文件".format(model_name)
logging.error(log)
raise IndexError(log)
logging.debug("使用接口[{}]创建模型请求[{}]".format(model.__class__.__name__, model_name))
return model
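
CompletionModel.__msg_handle__ flattens chat-style messages into a single text prompt before calling the Completion endpoint; a standalone sketch of that conversion with an invented history:

def messages_to_prompt(msgs: list) -> str:
    # Mirrors CompletionModel.__msg_handle__: one "role: content" line per turn, then an open assistant turn.
    prompt = ""
    for msg in msgs:
        prompt += "{}: {}\n".format(msg["role"], msg["content"])
    return prompt + "assistant: "

history = [
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "你好,有什么可以帮你?"},
    {"role": "user", "content": "今天天气如何"},
]
print(messages_to_prompt(history))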


@@ -1,28 +0,0 @@
# Billing module
# Deprecated, see https://github.com/RockChinQ/QChatGPT/issues/81
import logging

pricing = {
    "base": {  # text models are priced per 1000 characters
        "text-davinci-003": 0.02,
    },
    "image": {
        "256x256": 0.016,
        "512x512": 0.018,
        "1024x1024": 0.02,
    }
}


def language_base_price(model, text):
    salt_rate = 0.93
    length = ((len(text.encode('utf-8')) - len(text)) / 2 + len(text)) * salt_rate
    logging.debug("text length: %d" % length)
    return pricing["base"][model] * length / 1000


def image_price(size):
    logging.debug("image size: %s" % size)
    return pricing["image"][size]
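
A quick worked example of the (deprecated) pricing formula above, using an arbitrary mixed Chinese/English string:

text = "hello世界"                                  # 7 characters, 11 UTF-8 bytes
chars = len(text)                                   # 7
nbytes = len(text.encode("utf-8"))                  # 11 (ASCII chars take 1 byte, CJK chars 3)
weighted = ((nbytes - chars) / 2 + chars) * 0.93    # ((11 - 7) / 2 + 7) * 0.93 = 8.37
price = 0.02 * weighted / 1000                      # text-davinci-003 rate -> about 0.00017 USD
print(weighted, price)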


@@ -1,370 +0,0 @@
"""主线使用的会话管理模块
每个人每个群单独一个session,session内部保留了对话的上下文
"""
import logging
import threading
import time
import json
import pkg.openai.manager
import pkg.openai.modelmgr
import pkg.database.manager
import pkg.utils.context
import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models
# 运行时保存的所有session
sessions = {}
class SessionOfflineStatus:
ON_GOING = 'on_going'
EXPLICITLY_CLOSED = 'explicitly_closed'
# 重置session.prompt
def reset_session_prompt(session_name, prompt):
# 备份原始数据
bak_path = 'logs/{}-{}.bak'.format(
session_name,
time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
)
f = open(bak_path, 'w+', encoding='utf-8')
f.write(prompt)
f.close()
# 生成新数据
config = pkg.utils.context.get_config()
prompt = [
{
'role': 'system',
'content': config.default_prompt['default']
}
]
# 警告
logging.warning(
"""
用户[{}]的数据已被重置有可能是因为数据版本过旧或存储错误
原始数据将备份在
{}""".format(session_name, bak_path)
) # 为保证多行文本格式正确故无缩进
return prompt
# 从数据加载session
def load_sessions():
"""从数据库加载sessions"""
global sessions
db_inst = pkg.utils.context.get_database_manager()
session_data = db_inst.load_valid_sessions()
for session_name in session_data:
logging.info('加载session: {}'.format(session_name))
temp_session = Session(session_name)
temp_session.name = session_name
temp_session.create_timestamp = session_data[session_name]['create_timestamp']
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
try:
temp_session.prompt = json.loads(session_data[session_name]['prompt'])
except Exception:
temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt'])
temp_session.persistence()
sessions[session_name] = temp_session
# 获取指定名称的session如果不存在则创建一个新的
def get_session(session_name: str):
global sessions
if session_name not in sessions:
sessions[session_name] = Session(session_name)
return sessions[session_name]
def dump_session(session_name: str):
global sessions
if session_name in sessions:
assert isinstance(sessions[session_name], Session)
sessions[session_name].persistence()
del sessions[session_name]
# 通用的OpenAI API交互session
# session内部保留了对话的上下文
# 收到用户消息后将上下文提交给OpenAI API生成回复
class Session:
name = ''
prompt = []
"""使用list来保存会话中的回合"""
create_timestamp = 0
"""会话创建时间"""
last_interact_timestamp = 0
"""上次交互(产生回复)时间"""
just_switched_to_exist_session = False
response_lock = None
# 加锁
def acquire_response_lock(self):
logging.debug('{},lock acquire,{}'.format(self.name, self.response_lock))
self.response_lock.acquire()
logging.debug('{},lock acquire successfully,{}'.format(self.name, self.response_lock))
# 释放锁
def release_response_lock(self):
if self.response_lock.locked():
logging.debug('{},lock release,{}'.format(self.name, self.response_lock))
self.response_lock.release()
logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock))
# 从配置文件获取会话预设信息
def get_default_prompt(self, use_default: str = None):
config = pkg.utils.context.get_config()
import pkg.openai.dprompt as dprompt
if use_default is None:
current_default_prompt = dprompt.get_prompt(dprompt.get_current())
else:
current_default_prompt = dprompt.get_prompt(use_default)
return [
{
'role': 'user',
'content': current_default_prompt
}, {
'role': 'assistant',
'content': 'ok'
}
]
def __init__(self, name: str):
self.name = name
self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time())
self.schedule()
self.response_lock = threading.Lock()
self.prompt = self.get_default_prompt()
# 设定检查session最后一次对话是否超过过期时间的计时器
def schedule(self):
threading.Thread(target=self.expire_check_timer_loop, args=(self.create_timestamp,)).start()
# 检查session是否已经过期
def expire_check_timer_loop(self, create_timestamp: int):
global sessions
while True:
time.sleep(60)
# 不是此session已更换退出
if self.create_timestamp != create_timestamp or self not in sessions.values():
return
config = pkg.utils.context.get_config()
if int(time.time()) - self.last_interact_timestamp > config.session_expire_time:
logging.info('session {} 已过期'.format(self.name))
# 触发插件事件
args = {
'session_name': self.name,
'session': self,
'session_expire_time': config.session_expire_time
}
event = pkg.plugin.host.emit(plugin_models.SessionExpired, **args)
if event.is_prevented_default():
return
self.reset(expired=True, schedule_new=False)
# 删除此session
del sessions[self.name]
return
# 请求回复
# 这个函数是阻塞的
def append(self, text: str) -> str:
"""向session中添加一条消息返回接口回复"""
self.last_interact_timestamp = int(time.time())
# 触发插件事件
if self.prompt == self.get_default_prompt():
args = {
'session_name': self.name,
'session': self,
'default_prompt': self.prompt,
}
event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args)
if event.is_prevented_default():
return None
config = pkg.utils.context.get_config()
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
# 向API请求补全
message = pkg.utils.context.get_openai_manager().request_completion(
self.cut_out(text, max_length),
)
# 成功获取,处理回复
res_test = message
res_ans = res_test
# 去除开头可能的提示
res_ans_spt = res_test.split("\n\n")
if len(res_ans_spt) > 1:
del (res_ans_spt[0])
res_ans = '\n\n'.join(res_ans_spt)
# 将此次对话的双方内容加入到prompt中
self.prompt.append({'role': 'user', 'content': text})
self.prompt.append({'role': 'assistant', 'content': res_ans})
if self.just_switched_to_exist_session:
self.just_switched_to_exist_session = False
self.set_ongoing()
return res_ans if res_ans[0] != '\n' else res_ans[1:]
# 删除上一回合并返回上一回合的问题
def undo(self) -> str:
self.last_interact_timestamp = int(time.time())
# 删除最后两个消息
if len(self.prompt) < 2:
raise Exception('之前无对话,无法撤销')
question = self.prompt[-2]['content']
self.prompt = self.prompt[:-2]
# 返回上一回合的问题
return question
# 构建对话体
def cut_out(self, msg: str, max_tokens: int) -> list:
"""将现有prompt进行切割处理使得新的prompt长度不超过max_tokens"""
# 如果用户消息长度超过max_tokens直接返回
temp_prompt = [
{
'role': 'user',
'content': msg
}
]
token_count = len(msg)
# 倒序遍历prompt
for i in range(len(self.prompt) - 1, -1, -1):
if token_count >= max_tokens:
break
# 将prompt加到temp_prompt头部
temp_prompt.insert(0, self.prompt[i])
token_count += len(self.prompt[i]['content'])
logging.debug('cut_out: {}'.format(str(temp_prompt)))
return temp_prompt
# 持久化session
def persistence(self):
if self.prompt == self.get_default_prompt():
return
db_inst = pkg.utils.context.get_database_manager()
name_spt = self.name.split('_')
subject_type = name_spt[0]
subject_number = int(name_spt[1])
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
json.dumps(self.prompt))
# 重置session
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None):
if self.prompt[-1]['role'] != "system":
self.persistence()
if explicit:
# 触发插件事件
args = {
'session_name': self.name,
'session': self
}
# 此事件不支持阻止默认行为
_ = pkg.plugin.host.emit(plugin_models.SessionExplicitReset, **args)
pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
if expired:
pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
self.prompt = self.get_default_prompt(use_prompt)
self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time())
self.just_switched_to_exist_session = False
# self.response_lock = threading.Lock()
if schedule_new:
self.schedule()
# 将本session的数据库状态设置为on_going
def set_ongoing(self):
pkg.utils.context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)
# 切换到上一个session
def last_session(self):
last_one = pkg.utils.context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
if last_one is None:
return None
else:
self.persistence()
self.create_timestamp = last_one['create_timestamp']
self.last_interact_timestamp = last_one['last_interact_timestamp']
try:
self.prompt = json.loads(last_one['prompt'])
except json.decoder.JSONDecodeError:
self.prompt = reset_session_prompt(self.name, last_one['prompt'])
self.persistence()
self.just_switched_to_exist_session = True
return self
# 切换到下一个session
def next_session(self):
next_one = pkg.utils.context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
if next_one is None:
return None
else:
self.persistence()
self.create_timestamp = next_one['create_timestamp']
self.last_interact_timestamp = next_one['last_interact_timestamp']
try:
self.prompt = json.loads(next_one['prompt'])
except json.decoder.JSONDecodeError:
self.prompt = reset_session_prompt(self.name, next_one['prompt'])
self.persistence()
self.just_switched_to_exist_session = True
return self
def list_history(self, capacity: int = 10, page: int = 0):
return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page)
def draw_image(self, prompt: str):
return pkg.utils.context.get_openai_manager().request_image(prompt)
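
A standalone sketch of the truncation rule in Session.cut_out above (character counts stand in for tokens, exactly as in the original); the history used here is invented:

def cut_out(history: list, msg: str, max_len: int) -> list:
    # Walk the history backwards, prepending turns until the character budget is spent.
    result = [{"role": "user", "content": msg}]
    count = len(msg)
    for turn in reversed(history):
        if count >= max_len:
            break
        result.insert(0, turn)
        count += len(turn["content"])
    return result

history = [{"role": "user", "content": "a" * 500},
           {"role": "assistant", "content": "b" * 500},
           {"role": "user", "content": "c" * 100},
           {"role": "assistant", "content": "d" * 100}]
print([len(t["content"]) for t in cut_out(history, "hi", 300)])  # [500, 100, 100, 2]: the oldest turn is dropped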


@@ -1,50 +0,0 @@
import pkg.utils.context
def is_banned(launcher_type: str, launcher_id: int, sender_id: int) -> bool:
if not pkg.utils.context.get_qqbot_manager().enable_banlist:
return False
result = False
if launcher_type == 'group':
# 检查是否显式声明发起人QQ要被person忽略
if sender_id in pkg.utils.context.get_qqbot_manager().ban_person:
result = True
else:
for group_rule in pkg.utils.context.get_qqbot_manager().ban_group:
if type(group_rule) == int:
if group_rule == launcher_id: # 此群群号被禁用
result = True
elif type(group_rule) == str:
if group_rule.startswith('!'):
# 截取!后面的字符串作为表达式,判断是否匹配
reg_str = group_rule[1:]
import re
if re.match(reg_str, str(launcher_id)): # 被豁免,最高级别
result = False
break
else:
# 判断是否匹配regexp
import re
if re.match(group_rule, str(launcher_id)): # 此群群号被禁用
result = True
else:
# ban_person, 与群规则相同
for person_rule in pkg.utils.context.get_qqbot_manager().ban_person:
if type(person_rule) == int:
if person_rule == launcher_id:
result = True
elif type(person_rule) == str:
if person_rule.startswith('!'):
reg_str = person_rule[1:]
import re
if re.match(reg_str, str(launcher_id)):
result = False
break
else:
import re
if re.match(person_rule, str(launcher_id)):
result = True
return result
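
A condensed sketch of the group-rule evaluation in is_banned above; the rule list is illustrative. Plain integers match a group number, strings are regular expressions, and a leading "!" marks an exemption that overrides everything else:

import re

ban_group = [12345678, r"^999\d+$", "!^99912345$"]  # hypothetical rules

def group_banned(group_id: int) -> bool:
    banned = False
    for rule in ban_group:
        if isinstance(rule, int):
            if rule == group_id:
                banned = True
        elif rule.startswith("!"):
            if re.match(rule[1:], str(group_id)):  # exemption: highest priority, stop immediately
                return False
        elif re.match(rule, str(group_id)):
            banned = True
    return banned

print(group_banned(12345678))  # True  (exact id match)
print(group_banned(99900000))  # True  (regexp match)
print(group_banned(99912345))  # False (exempted by the "!" rule)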


@@ -1,105 +0,0 @@
# 长消息处理相关
import logging
import os
import time
import base64
import config
from mirai.models.message import MessageComponent, MessageChain, Image
from mirai.models.message import ForwardMessageNode
from mirai.models.base import MiraiBaseModel
from typing import List
import pkg.utils.context as context
import pkg.utils.text2img as text2img
class ForwardMessageDiaplay(MiraiBaseModel):
title: str = "群聊的聊天记录"
brief: str = "[聊天记录]"
source: str = "聊天记录"
preview: List[str] = []
summary: str = "查看x条转发消息"
class Forward(MessageComponent):
"""合并转发。"""
type: str = "Forward"
"""消息组件类型。"""
display: ForwardMessageDiaplay
"""显示信息"""
node_list: List[ForwardMessageNode]
"""转发消息节点列表。"""
def __init__(self, *args, **kwargs):
    if len(args) == 1:
        self.node_list = args[0]
        super().__init__(**kwargs)
        return
    super().__init__(*args, **kwargs)
def __str__(self):
return '[聊天记录]'
def text_to_image(text: str) -> MessageComponent:
"""将文本转换成图片"""
# 检查temp文件夹是否存在
if not os.path.exists('temp'):
os.mkdir('temp')
img_path = text2img.text_to_image(text_str=text, save_as='temp/{}.png'.format(int(time.time())))
compressed_path, size = text2img.compress_image(img_path, outfile="temp/{}_compressed.png".format(int(time.time())))
# 读取图片转换成base64
with open(compressed_path, 'rb') as f:
img = f.read()
b64 = base64.b64encode(img)
# 删除图片
os.remove(img_path)
# 判断compressed_path是否存在
if os.path.exists(compressed_path):
os.remove(compressed_path)
# 返回图片
return Image(base64=b64.decode('utf-8'))
def check_text(text: str) -> list:
"""检查文本是否为长消息,并转换成该使用的消息链组件"""
if not hasattr(config, 'blob_message_threshold'):
return [text]
if len(text) > config.blob_message_threshold:
if not hasattr(config, 'blob_message_strategy'):
raise AttributeError('未定义长消息处理策略')
# logging.info("长消息: {}".format(text))
if config.blob_message_strategy == 'image':
# 转换成图片
return [text_to_image(text)]
elif config.blob_message_strategy == 'forward':
# 敏感词屏蔽
text = context.get_qqbot_manager().reply_filter.process(text)
# 包装转发消息
display = ForwardMessageDiaplay(
title='群聊的聊天记录',
brief='[聊天记录]',
source='聊天记录',
preview=["bot: "+text],
summary="查看1条转发消息"
)
node = ForwardMessageNode(
sender_id=config.mirai_http_api_config['qq'],
sender_name='bot',
message_chain=MessageChain([text])
)
forward = Forward(
display=display,
node_list=[node]
)
return [forward]
else:
return [text]
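
The behaviour of check_text above reduces to a length threshold plus a strategy switch; a stripped-down sketch with example config values:

blob_message_threshold = 256       # stand-in for config.blob_message_threshold
blob_message_strategy = "image"    # or "forward"

def long_message_action(text: str) -> str:
    # Short messages pass through; long ones follow the configured strategy.
    if len(text) <= blob_message_threshold:
        return "send as plain text"
    if blob_message_strategy == "image":
        return "render to an image via text2img"
    if blob_message_strategy == "forward":
        return "wrap into a Forward message node"
    return "send as plain text"

print(long_message_action("短消息"))    # send as plain text
print(long_message_action("长" * 300))  # render to an image via text2img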


@@ -1,359 +0,0 @@
# 指令处理模块
import logging
import json
import datetime
import os
import threading
import pkg.openai.session
import pkg.openai.manager
import pkg.utils.reloader
import pkg.utils.updater
import pkg.utils.context
import pkg.qqbot.message
import pkg.utils.credit as credit
from mirai import Image
def config_operation(cmd, params):
reply = []
config = pkg.utils.context.get_config()
reply_str = ""
if len(params) == 0:
reply = ["[bot]err:请输入配置项"]
else:
cfg_name = params[0]
if cfg_name == 'all':
reply_str = "[bot]所有配置项:\n\n"
for cfg in dir(config):
if not cfg.startswith('__') and not cfg == 'logging':
# 根据配置项类型进行格式化如果是字典则转换为json并格式化
if isinstance(getattr(config, cfg), str):
reply_str += "{}: \"{}\"\n".format(cfg, getattr(config, cfg))
elif isinstance(getattr(config, cfg), dict):
# 不进行unicode转义并格式化
reply_str += "{}: {}\n".format(cfg,
json.dumps(getattr(config, cfg),
ensure_ascii=False, indent=4))
else:
reply_str += "{}: {}\n".format(cfg, getattr(config, cfg))
reply = [reply_str]
elif cfg_name in dir(config):
if len(params) == 1:
# 按照配置项类型进行格式化
if isinstance(getattr(config, cfg_name), str):
reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, getattr(config, cfg_name))
elif isinstance(getattr(config, cfg_name), dict):
reply_str = "[bot]配置项{}: {}\n".format(cfg_name,
json.dumps(getattr(config, cfg_name),
ensure_ascii=False, indent=4))
else:
reply_str = "[bot]配置项{}: {}\n".format(cfg_name, getattr(config, cfg_name))
reply = [reply_str]
else:
cfg_value = " ".join(params[1:])
# 类型转换如果是json则转换为字典
if cfg_value == 'true':
cfg_value = True
elif cfg_value == 'false':
cfg_value = False
elif cfg_value.isdigit():
cfg_value = int(cfg_value)
elif cfg_value.startswith('{') and cfg_value.endswith('}'):
cfg_value = json.loads(cfg_value)
else:
try:
cfg_value = float(cfg_value)
except ValueError:
pass
# 检查类型是否匹配
if isinstance(getattr(config, cfg_name), type(cfg_value)):
setattr(config, cfg_name, cfg_value)
pkg.utils.context.set_config(config)
reply = ["[bot]配置项{}修改成功".format(cfg_name)]
else:
reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)]
else:
reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
return reply
def plugin_operation(cmd, params, is_admin):
reply = []
import pkg.plugin.host as plugin_host
import pkg.utils.updater as updater
plugin_list = plugin_host.__plugins__
if len(params) == 0:
reply_str = "[bot]所有插件({}):\n".format(len(plugin_host.__plugins__))
idx = 0
for key in plugin_host.iter_plugins_name():
plugin = plugin_list[key]
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\
.format((idx+1), plugin['name'],
"[已禁用]" if not plugin['enabled'] else "",
plugin['description'],
plugin['version'], plugin['author'])
if updater.is_repo("/".join(plugin['path'].split('/')[:-1])):
remote_url = updater.get_remote_url("/".join(plugin['path'].split('/')[:-1]))
if remote_url != "https://github.com/RockChinQ/QChatGPT" and remote_url != "https://gitee.com/RockChin/QChatGPT":
reply_str += "源码: "+remote_url+"\n"
idx += 1
reply = [reply_str]
elif params[0] == 'update':
# 更新所有插件
if is_admin:
def closure():
import pkg.utils.context
updated = []
for key in plugin_list:
plugin = plugin_list[key]
if updater.is_repo("/".join(plugin['path'].split('/')[:-1])):
success = updater.pull_latest("/".join(plugin['path'].split('/')[:-1]))
if success:
updated.append(plugin['name'])
# 检查是否有requirements.txt
pkg.utils.context.get_qqbot_manager().notify_admin("正在安装依赖...")
for key in plugin_list:
plugin = plugin_list[key]
if os.path.exists("/".join(plugin['path'].split('/')[:-1])+"/requirements.txt"):
logging.info("{}检测到requirements.txt安装依赖".format(plugin['name']))
import pkg.utils.pkgmgr
pkg.utils.pkgmgr.install_requirements("/".join(plugin['path'].split('/')[:-1])+"/requirements.txt")
import main
main.reset_logging()
pkg.utils.context.get_qqbot_manager().notify_admin("已更新插件: {}".format(", ".join(updated)))
threading.Thread(target=closure).start()
reply = ["[bot]正在更新所有插件,请勿重复发起..."]
else:
reply = ["[bot]err:权限不足"]
elif params[0].startswith("http"):
if is_admin:
def closure():
try:
plugin_host.install_plugin(params[0])
pkg.utils.context.get_qqbot_manager().notify_admin("插件安装成功,请发送 !reload 指令重载插件")
except Exception as e:
logging.error("插件安装失败:{}".format(e))
pkg.utils.context.get_qqbot_manager().notify_admin("插件安装失败:{}".format(e))
threading.Thread(target=closure, args=()).start()
reply = ["[bot]正在安装插件..."]
else:
reply = ["[bot]err:权限不足,请使用管理员账号私聊发起"]
return reply
def process_command(session_name: str, text_message: str, mgr, config,
launcher_type: str, launcher_id: int, sender_id: int, is_admin: bool) -> list:
reply = []
try:
logging.info(
"[{}]发起指令:{}".format(session_name, text_message[:min(20, len(text_message))] + (
"..." if len(text_message) > 20 else "")))
cmd = text_message[1:].strip().split(' ')[0]
params = text_message[1:].strip().split(' ')[1:]
if cmd == 'help':
reply = ["[bot]" + config.help_message]
elif cmd == 'reset':
if len(params) == 0:
pkg.openai.session.get_session(session_name).reset(explicit=True)
reply = ["[bot]会话已重置"]
else:
pkg.openai.session.get_session(session_name).reset(explicit=True, use_prompt=params[0])
reply = ["[bot]会话已重置,使用场景预设:{}".format(params[0])]
elif cmd == 'last':
result = pkg.openai.session.get_session(session_name).last_session()
if result is None:
reply = ["[bot]没有前一次的对话"]
else:
datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime(
'%Y-%m-%d %H:%M:%S')
reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format(datetime_str)]
elif cmd == 'next':
result = pkg.openai.session.get_session(session_name).next_session()
if result is None:
reply = ["[bot]没有后一次的对话"]
else:
datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime(
'%Y-%m-%d %H:%M:%S')
reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format(datetime_str)]
elif cmd == 'prompt':
msgs = ""
session:list = pkg.openai.session.get_session(session_name).prompt
for msg in session:
if len(params) != 0 and params[0] in ['-all', '-a']:
msgs = msgs + "{}: {}\n\n".format(msg['role'], msg['content'])
elif len(msg['content']) > 30:
msgs = msgs + "[{}]: {}...\n\n".format(msg['role'], msg['content'][:30])
else:
msgs = msgs + "[{}]: {}\n\n".format(msg['role'], msg['content'])
reply = ["[bot]当前对话所有内容:\n{}".format(msgs)]
elif cmd == 'list':
pkg.openai.session.get_session(session_name).persistence()
page = 0
if len(params) > 0:
try:
page = int(params[0])
except ValueError:
pass
results = pkg.openai.session.get_session(session_name).list_history(page=page)
if len(results) == 0:
reply = ["[bot]第{}页没有历史会话".format(page)]
else:
reply_str = "[bot]历史会话 第{}页:\n".format(page)
current = -1
for i in range(len(results)):
# 时间(使用create_timestamp转换) 序号 部分内容
datetime_obj = datetime.datetime.fromtimestamp(results[i]['create_timestamp'])
msg = ""
try:
msg = json.loads(results[i]['prompt'])
except json.decoder.JSONDecodeError:
msg = pkg.openai.session.reset_session_prompt(session_name, results[i]['prompt'])
# 持久化
pkg.openai.session.get_session(session_name).persistence()
if len(msg) >= 2:
reply_str += "#{} 创建:{} {}\n".format(i + page * 10,
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"),
msg[1]['content'])
else:
reply_str += "#{} 创建:{} {}\n".format(i + page * 10,
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"),
"无内容")
if results[i]['create_timestamp'] == pkg.openai.session.get_session(
session_name).create_timestamp:
current = i + page * 10
reply_str += "\n以上信息倒序排列"
if current != -1:
reply_str += ",当前会话是 #{}\n".format(current)
else:
reply_str += ",当前处于全新会话或不在此页"
reply = [reply_str]
elif cmd == 'resend':
session = pkg.openai.session.get_session(session_name)
to_send = session.undo()
reply = pkg.qqbot.message.process_normal_message(to_send, mgr, config,
launcher_type, launcher_id, sender_id)
elif cmd == 'usage':
reply_str = "[bot]各api-key使用情况:\n\n"
api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key
for key_name in api_keys:
text_length = pkg.utils.context.get_openai_manager().audit_mgr \
.get_text_length_of_key(api_keys[key_name])
image_count = pkg.utils.context.get_openai_manager().audit_mgr \
.get_image_count_of_key(api_keys[key_name])
reply_str += "{}:\n - 文本长度:{}\n - 图片数量:{}\n".format(key_name, int(text_length),
int(image_count))
# 获取此key的额度
try:
credit_data = credit.fetch_credit_data(api_keys[key_name])
reply_str += " - 使用额度:{:.2f}/{:.2f}\n".format(credit_data['total_used'],credit_data['total_granted'])
except Exception as e:
logging.warning("获取额度失败:{}".format(e))
reply = [reply_str]
elif cmd == 'draw':
if len(params) == 0:
reply = ["[bot]err:请输入图片描述文字"]
else:
session = pkg.openai.session.get_session(session_name)
res = session.draw_image(" ".join(params))
logging.debug("draw_image result:{}".format(res))
reply = [Image(url=res['data'][0]['url'])]
if not (hasattr(config, 'include_image_description')
and not config.include_image_description):
reply.append(" ".join(params))
elif cmd == 'version':
reply_str = "[bot]当前版本:\n{}\n".format(pkg.utils.updater.get_current_version_info())
try:
if pkg.utils.updater.is_new_version_available():
reply_str += "\n有新版本可用,请使用命令 !update 进行更新"
except:
pass
reply = [reply_str]
elif cmd == 'plugin':
reply = plugin_operation(cmd, params, is_admin)
elif cmd == 'default':
if len(params) == 0:
# 输出目前所有情景预设
import pkg.openai.dprompt as dprompt
reply_str = "[bot]当前所有情景预设:\n\n"
for key,value in dprompt.get_prompt_dict().items():
reply_str += " - {}: {}\n".format(key,value)
reply_str += "\n当前默认情景预设:{}\n".format(dprompt.get_current())
reply_str += "请使用!default <情景预设>来设置默认情景预设"
reply = [reply_str]
elif len(params) >0 and is_admin:
# 设置默认情景
import pkg.openai.dprompt as dprompt
try:
dprompt.set_current(params[0])
reply = ["[bot]已设置默认情景预设为:{}".format(dprompt.get_current())]
except KeyError:
reply = ["[bot]err: 未找到情景预设:{}".format(params[0])]
else:
reply = ["[bot]err: 仅管理员可设置默认情景预设"]
elif cmd == 'reload' and is_admin:
def reload_task():
pkg.utils.reloader.reload_all()
threading.Thread(target=reload_task, daemon=True).start()
elif cmd == 'update' and is_admin:
def update_task():
try:
if pkg.utils.updater.update_all():
pkg.utils.reloader.reload_all(notify=False)
pkg.utils.context.get_qqbot_manager().notify_admin("更新完成")
else:
pkg.utils.context.get_qqbot_manager().notify_admin("无新版本")
except Exception as e0:
pkg.utils.context.get_qqbot_manager().notify_admin("更新失败:{}".format(e0))
return
threading.Thread(target=update_task, daemon=True).start()
reply = ["[bot]正在更新,请耐心等待,请勿重复发起更新..."]
elif cmd == 'cfg' and is_admin:
reply = config_operation(cmd, params)
else:
if cmd.startswith("~") and is_admin:
config_item = cmd[1:]
params = [config_item] + params
reply = config_operation("cfg", params)
else:
reply = ["[bot]err:未知的指令或权限不足: " + cmd]
except Exception as e:
mgr.notify_admin("{}指令执行失败:{}".format(session_name, e))
logging.exception(e)
reply = ["[bot]err:{}".format(e)]
return reply
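
The !cfg handler above coerces the raw argument string before checking its type against the existing config value; a standalone sketch of that coercion order:

import json

def coerce(raw: str):
    # Same order as in config_operation: bool words, integers, JSON objects, floats, then plain strings.
    if raw == "true":
        return True
    if raw == "false":
        return False
    if raw.isdigit():
        return int(raw)
    if raw.startswith("{") and raw.endswith("}"):
        return json.loads(raw)
    try:
        return float(raw)
    except ValueError:
        return raw

for raw in ["true", "42", "0.5", '{"size": "256x256"}', "hello"]:
    print(repr(coerce(raw)))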


@@ -1,84 +0,0 @@
# 敏感词过滤模块
import re
import requests
import json
import logging
class ReplyFilter:
sensitive_words = []
mask = "*"
mask_word = ""
# 默认值( 兼容性考虑 )
baidu_check = False
baidu_api_key = ""
baidu_secret_key = ""
inappropriate_message_tips = "[百度云]请珍惜机器人,当前返回内容不合规"
def __init__(self, sensitive_words: list, mask: str = "*", mask_word: str = ""):
self.sensitive_words = sensitive_words
self.mask = mask
self.mask_word = mask_word
import config
if hasattr(config, 'baidu_check') and hasattr(config, 'baidu_api_key') and hasattr(config, 'baidu_secret_key'):
self.baidu_check = config.baidu_check
self.baidu_api_key = config.baidu_api_key
self.baidu_secret_key = config.baidu_secret_key
self.inappropriate_message_tips = config.inappropriate_message_tips
def is_illegal(self, message: str) -> bool:
processed = self.process(message)
if processed != message:
return True
return False
def process(self, message: str) -> str:
# 本地关键词屏蔽
for word in self.sensitive_words:
match = re.findall(word, message)
if len(match) > 0:
for i in range(len(match)):
if self.mask_word == "":
message = message.replace(match[i], self.mask * len(match[i]))
else:
message = message.replace(match[i], self.mask_word)
# 百度云审核
if self.baidu_check:
# 百度云审核URL
baidu_url = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token=" + \
str(requests.post("https://aip.baidubce.com/oauth/2.0/token",
params={"grant_type": "client_credentials",
"client_id": self.baidu_api_key,
"client_secret": self.baidu_secret_key}).json().get("access_token"))
# 百度云审核
payload = "text=" + message
logging.info("向百度云发送:" + payload)
headers = {'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'}
if isinstance(payload, str):
payload = payload.encode('utf-8')
response = requests.request("POST", baidu_url, headers=headers, data=payload)
response_dict = json.loads(response.text)
if "error_code" in response_dict:
error_msg = response_dict.get("error_msg")
logging.warning(f"百度云判定出错,错误信息:{error_msg}")
conclusion = f"百度云判定出错,错误信息:{error_msg}\n以下是原消息:{message}"
else:
conclusion = response_dict["conclusion"]
if conclusion == "合规":
logging.info(f"百度云判定结果:{conclusion}")
return message
else:
logging.warning(f"百度云判定结果:{conclusion}")
conclusion = self.inappropriate_message_tips
# 返回百度云审核结果
return conclusion
return message
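
A standalone sketch of the local masking pass in ReplyFilter.process above (the word list and mask are examples; the Baidu review branch is left out):

import re

sensitive_words = [r"敏感词", r"bad\w*"]  # hypothetical patterns
mask = "*"

def mask_text(message: str) -> str:
    # For every pattern, replace each match with an equal-length run of the mask character.
    for word in sensitive_words:
        for hit in re.findall(word, message):
            message = message.replace(hit, mask * len(hit))
    return message

print(mask_text("这句话包含敏感词和 badword"))  # -> 这句话包含***和 *******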


@@ -1,19 +0,0 @@
import re


def ignore(msg: str) -> bool:
    """Check whether a message should be ignored."""
    import config

    if not hasattr(config, 'ignore_rules'):
        return False

    if 'prefix' in config.ignore_rules:
        for rule in config.ignore_rules['prefix']:
            if msg.startswith(rule):
                return True

    if 'regexp' in config.ignore_rules:
        for rule in config.ignore_rules['regexp']:
            if re.search(rule, msg):
                return True

    return False
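
The ignore_rules structure this module reads from config looks like the following; the concrete values are illustrative:

# config.py (excerpt, hypothetical values)
ignore_rules = {
    "prefix": ["#", "/"],           # drop messages starting with any of these
    "regexp": [r"^\[自动回复\]"],    # drop messages matching any of these patterns
}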


@@ -1,357 +0,0 @@
import asyncio
import json
import os
import threading
from concurrent.futures import ThreadPoolExecutor
import mirai.models.bus
from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
FriendMessage, Image
from func_timeout import func_set_timeout
import pkg.openai.session
import pkg.openai.manager
from func_timeout import FunctionTimedOut
import logging
import pkg.qqbot.filter
import pkg.qqbot.process as processor
import pkg.utils.context
import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models
# 检查消息是否符合泛响应匹配机制
def check_response_rule(text: str):
config = pkg.utils.context.get_config()
if not hasattr(config, 'response_rules'):
return False, ''
rules = config.response_rules
# 检查前缀匹配
if 'prefix' in rules:
for rule in rules['prefix']:
if text.startswith(rule):
return True, text.replace(rule, "", 1)
# 检查正则表达式匹配
if 'regexp' in rules:
for rule in rules['regexp']:
import re
match = re.match(rule, text)
if match:
return True, text
return False, ""
def response_at():
config = pkg.utils.context.get_config()
if 'at' not in config.response_rules:
return True
return config.response_rules['at']
def random_responding():
config = pkg.utils.context.get_config()
if 'random_rate' in config.response_rules:
import random
return random.random() < config.response_rules['random_rate']
return False
# 控制QQ消息输入输出的类
class QQBotManager:
retry = 3
#线程池控制
pool = None
bot: Mirai = None
reply_filter = None
enable_banlist = False
ban_person = []
ban_group = []
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, pool_num: int = 10, first_time_init=True):
self.timeout = timeout
self.retry = retry
self.pool_num = pool_num
self.pool = ThreadPoolExecutor(max_workers=self.pool_num)
logging.debug("Registered thread pool Size:{}".format(pool_num))
# 加载禁用列表
if os.path.exists("banlist.py"):
import banlist
self.enable_banlist = banlist.enable
self.ban_person = banlist.person
self.ban_group = banlist.group
logging.info("加载禁用列表: person: {}, group: {}".format(self.ban_person, self.ban_group))
config = pkg.utils.context.get_config()
if os.path.exists("sensitive.json") \
and config.sensitive_word_filter is not None \
and config.sensitive_word_filter:
with open("sensitive.json", "r", encoding="utf-8") as f:
sensitive_json = json.load(f)
self.reply_filter = pkg.qqbot.filter.ReplyFilter(
sensitive_words=sensitive_json['words'],
mask=sensitive_json['mask'] if 'mask' in sensitive_json else '*',
mask_word=sensitive_json['mask_word'] if 'mask_word' in sensitive_json else ''
)
else:
self.reply_filter = pkg.qqbot.filter.ReplyFilter([])
# 由于YiriMirai的bot对象是单例的且shutdown方法暂时无法使用
# 故只在第一次初始化时创建bot对象重载之后使用原bot对象
# 因此bot的配置不支持热重载
if first_time_init:
self.first_time_init(mirai_http_api_config)
else:
self.bot = pkg.utils.context.get_qqbot_manager().bot
pkg.utils.context.set_qqbot_manager(self)
# Caution: 注册新的事件处理器之后请务必在unsubscribe_all中编写相应的取消订阅代码
@self.bot.on(FriendMessage)
async def on_friend_message(event: FriendMessage):
def friend_message_handler(event: FriendMessage):
# 触发事件
args = {
"launcher_type": "person",
"launcher_id": event.sender.id,
"sender_id": event.sender.id,
"message_chain": event.message_chain,
}
plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args)
if plugin_event.is_prevented_default():
return
self.on_person_message(event)
self.go(friend_message_handler, event)
@self.bot.on(StrangerMessage)
async def on_stranger_message(event: StrangerMessage):
def stranger_message_handler(event: StrangerMessage):
# 触发事件
args = {
"launcher_type": "person",
"launcher_id": event.sender.id,
"sender_id": event.sender.id,
"message_chain": event.message_chain,
}
plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args)
if plugin_event.is_prevented_default():
return
self.on_person_message(event)
self.go(stranger_message_handler, event)
@self.bot.on(GroupMessage)
async def on_group_message(event: GroupMessage):
def group_message_handler(event: GroupMessage):
# 触发事件
args = {
"launcher_type": "group",
"launcher_id": event.group.id,
"sender_id": event.sender.id,
"message_chain": event.message_chain,
}
plugin_event = plugin_host.emit(plugin_models.GroupMessageReceived, **args)
if plugin_event.is_prevented_default():
return
self.on_group_message(event)
self.go(group_message_handler, event)
def unsubscribe_all():
"""取消所有订阅
用于在热重载流程中卸载所有事件处理器
"""
assert isinstance(self.bot, Mirai)
bus = self.bot.bus
assert isinstance(bus, mirai.models.bus.ModelEventBus)
bus.unsubscribe(FriendMessage, on_friend_message)
bus.unsubscribe(StrangerMessage, on_stranger_message)
bus.unsubscribe(GroupMessage, on_group_message)
self.unsubscribe_all = unsubscribe_all
def go(self, func, *args, **kwargs):
self.pool.submit(func, *args, **kwargs)
def first_time_init(self, mirai_http_api_config: dict):
"""热重载后不再运行此函数"""
if 'adapter' not in mirai_http_api_config or mirai_http_api_config['adapter'] == "WebSocketAdapter":
bot = Mirai(
qq=mirai_http_api_config['qq'],
adapter=WebSocketAdapter(
verify_key=mirai_http_api_config['verifyKey'],
host=mirai_http_api_config['host'],
port=mirai_http_api_config['port']
)
)
elif mirai_http_api_config['adapter'] == "HTTPAdapter":
bot = Mirai(
qq=mirai_http_api_config['qq'],
adapter=HTTPAdapter(
verify_key=mirai_http_api_config['verifyKey'],
host=mirai_http_api_config['host'],
port=mirai_http_api_config['port']
)
)
else:
raise Exception("未知的适配器类型")
self.bot = bot
def send(self, event, msg, check_quote=True):
config = pkg.utils.context.get_config()
asyncio.run(
self.bot.send(event, msg, quote=True if hasattr(config,
"quote_origin") and config.quote_origin and check_quote else False))
# 私聊消息处理
def on_person_message(self, event: MessageEvent):
import config
reply = ''
if event.sender.id == self.bot.qq:
pass
else:
if Image in event.message_chain:
pass
else:
# 超时则重试,重试超过次数则放弃
failed = 0
for i in range(self.retry):
try:
@func_set_timeout(config.process_message_timeout)
def time_ctrl_wrapper():
reply = processor.process_message('person', event.sender.id, str(event.message_chain),
event.message_chain,
event.sender.id)
return reply
reply = time_ctrl_wrapper()
break
except FunctionTimedOut:
logging.warning("person_{}: 超时,重试中({})".format(event.sender.id, i))
pkg.openai.session.get_session('person_{}'.format(event.sender.id)).release_response_lock()
if "person_{}".format(event.sender.id) in pkg.qqbot.process.processing:
pkg.qqbot.process.processing.remove('person_{}'.format(event.sender.id))
failed += 1
continue
if failed == self.retry:
pkg.openai.session.get_session('person_{}'.format(event.sender.id)).release_response_lock()
self.notify_admin("{} 请求超时".format("person_{}".format(event.sender.id)))
reply = ["[bot]err:请求超时"]
if reply:
return self.send(event, reply, check_quote=False)
# 群消息处理
def on_group_message(self, event: GroupMessage):
import config
reply = ''
def process(text=None) -> str:
replys = ""
if At(self.bot.qq) in event.message_chain:
event.message_chain.remove(At(self.bot.qq))
# 超时则重试,重试超过次数则放弃
failed = 0
for i in range(self.retry):
try:
@func_set_timeout(config.process_message_timeout)
def time_ctrl_wrapper():
replys = processor.process_message('group', event.group.id,
str(event.message_chain).strip() if text is None else text,
event.message_chain,
event.sender.id)
return replys
replys = time_ctrl_wrapper()
break
except FunctionTimedOut:
logging.warning("group_{}: 超时,重试中({})".format(event.group.id, i))
pkg.openai.session.get_session('group_{}'.format(event.group.id)).release_response_lock()
if "group_{}".format(event.group.id) in pkg.qqbot.process.processing:
pkg.qqbot.process.processing.remove('group_{}'.format(event.group.id))
failed += 1
continue
if failed == self.retry:
pkg.openai.session.get_session('group_{}'.format(event.group.id)).release_response_lock()
self.notify_admin("{} 请求超时".format("group_{}".format(event.group.id)))
replys = ["[bot]err:请求超时"]
return replys
if Image in event.message_chain:
pass
else:
if At(self.bot.qq) in event.message_chain and response_at():
# 直接调用
reply = process()
else:
check, result = check_response_rule(str(event.message_chain).strip())
if check:
reply = process(result.strip())
# 检查是否随机响应
elif random_responding():
logging.info("随机响应group_{}消息".format(event.group.id))
reply = process()
if reply:
return self.send(event, reply)
# 通知系统管理员
def notify_admin(self, message: str):
config = pkg.utils.context.get_config()
if hasattr(config, "admin_qq") and config.admin_qq != 0 and config.admin_qq != []:
logging.info("通知管理员:{}".format(message))
if type(config.admin_qq) == int:
send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message))
threading.Thread(target=asyncio.run, args=(send_task,)).start()
else:
for adm in config.admin_qq:
send_task = self.bot.send_friend_message(adm, "[bot]{}".format(message))
threading.Thread(target=asyncio.run, args=(send_task,)).start()
def notify_admin_message_chain(self, message):
config = pkg.utils.context.get_config()
if hasattr(config, "admin_qq") and config.admin_qq != 0 and config.admin_qq != []:
logging.info("通知管理员:{}".format(message))
if type(config.admin_qq) == int:
send_task = self.bot.send_friend_message(config.admin_qq, message)
threading.Thread(target=asyncio.run, args=(send_task,)).start()
else:
for adm in config.admin_qq:
send_task = self.bot.send_friend_message(adm, message)
threading.Thread(target=asyncio.run, args=(send_task,)).start()
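
A self-contained sketch of the wildcard-response matching done by check_response_rule near the top of this file; the rules dict stands in for config.response_rules:

import re

response_rules = {"prefix": ["ai", "bot"], "regexp": [r".*[??]$"]}  # hypothetical config

def match(text: str):
    for rule in response_rules.get("prefix", []):
        if text.startswith(rule):
            return True, text.replace(rule, "", 1)  # strip the prefix once, as in the original
    for rule in response_rules.get("regexp", []):
        if re.match(rule, text):
            return True, text                       # regexp hits keep the full text
    return False, ""

print(match("ai 今天天气如何"))  # (True, ' 今天天气如何')
print(match("这是什么?"))       # (True, '这是什么?')
print(match("随便聊聊"))         # (False, '')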


@@ -1,130 +0,0 @@
# 普通消息处理模块
import logging
import time
import openai
import pkg.utils.context
import pkg.openai.session
import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models
import pkg.qqbot.blob as blob
def handle_exception(notify_admin: str = "", set_reply: str = "") -> list:
"""处理异常当notify_admin不为空时会通知管理员返回通知用户的消息"""
import config
pkg.utils.context.get_qqbot_manager().notify_admin(notify_admin)
if hasattr(config, 'hide_exce_info_to_user') and config.hide_exce_info_to_user:
if hasattr(config, 'alter_tip_message'):
return [config.alter_tip_message] if config.alter_tip_message else []
else:
return ["[bot]出错了,请重试或联系管理员"]
else:
return [set_reply]
def process_normal_message(text_message: str, mgr, config, launcher_type: str,
launcher_id: int, sender_id: int) -> list:
session_name = f"{launcher_type}_{launcher_id}"
logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + (
"..." if len(text_message) > 20 else "")))
session = pkg.openai.session.get_session(session_name)
unexpected_exception_times = 0
max_unexpected_exception_times = 3
reply = []
while True:
if unexpected_exception_times >= max_unexpected_exception_times:
reply = handle_exception(notify_admin=f"{session_name},多次尝试失败。", set_reply=f"[bot]多次尝试失败,请重试或联系管理员")
break
try:
prefix = "[GPT]" if hasattr(config, "show_prefix") and config.show_prefix else ""
text = session.append(text_message)
# 触发插件事件
args = {
"launcher_type": launcher_type,
"launcher_id": launcher_id,
"sender_id": sender_id,
"session": session,
"prefix": prefix,
"response_text": text
}
event = pkg.plugin.host.emit(plugin_models.NormalMessageResponded, **args)
if event.get_return_value("prefix") is not None:
prefix = event.get_return_value("prefix")
if event.get_return_value("reply") is not None:
reply = event.get_return_value("reply")
if not event.is_prevented_default():
reply = blob.check_text(prefix + text)
except openai.error.APIConnectionError as e:
err_msg = str(e)
if err_msg.__contains__('Error communicating with OpenAI'):
reply = handle_exception("{}会话调用API失败:{}\n请尝试关闭网络代理来解决此问题。".format(session_name, e),
"[bot]err:调用API失败请重试或联系管理员或等待修复")
else:
reply = handle_exception("{}会话调用API失败:{}".format(session_name, e), "[bot]err:调用API失败请重试或联系管理员或等待修复")
except openai.error.RateLimitError as e:
logging.debug(type(e))
logging.debug(e.error['message'])
if 'message' in e.error and e.error['message'].__contains__('You exceeded your current quota'):
# 尝试切换api-key
current_key_name = pkg.utils.context.get_openai_manager().key_mgr.get_key_name(
pkg.utils.context.get_openai_manager().key_mgr.using_key
)
pkg.utils.context.get_openai_manager().key_mgr.set_current_exceeded()
# 触发插件事件
args = {
'key_name': current_key_name,
'usage': pkg.utils.context.get_openai_manager().audit_mgr
.get_usage(pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()),
'exceeded_keys': pkg.utils.context.get_openai_manager().key_mgr.exceeded,
}
event = plugin_host.emit(plugin_models.KeyExceeded, **args)
if not event.is_prevented_default():
switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch()
if not switched:
reply = handle_exception(
"api-key调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key如果你认为这是误判请尝试重启程序。".format(
current_key_name), "[bot]err:API调用额度超额请联系管理员或等待修复")
else:
openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key()
mgr.notify_admin("api-key调用额度超限({}),接口报错,已切换到{}".format(current_key_name, name))
reply = ["[bot]err:API调用额度超额已自动切换请重新发送消息"]
continue
elif 'message' in e.error and e.error['message'].__contains__('You can retry your request'):
# 重试
unexpected_exception_times += 1
continue
elif 'message' in e.error and e.error['message']\
.__contains__('The server had an error while processing your request'):
# 重试
unexpected_exception_times += 1
continue
else:
reply = handle_exception("{}会话调用API失败:{}".format(session_name, e),
"[bot]err:RateLimitError,请重试或联系作者,或等待修复")
except openai.error.InvalidRequestError as e:
reply = handle_exception("{}API调用参数错误:{}\n\n这可能是由于config.py中的prompt_submit_length参数或"
"completion_api_params中的max_tokens参数数值过大导致的请尝试将其降低".format(
session_name, e), "[bot]err:API调用参数错误请联系管理员或等待修复")
except openai.error.ServiceUnavailableError as e:
reply = handle_exception("{}API调用服务不可用:{}".format(session_name, e), "[bot]err:API调用服务不可用请重试或联系管理员或等待修复")
except Exception as e:
logging.exception(e)
reply = handle_exception("{}会话处理异常:{}".format(session_name, e), "[bot]err:{}".format(e))
break
return reply
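
handle_exception above chooses the user-facing reply from config; a small sketch of that decision with a stand-in config object:

from types import SimpleNamespace

def choose_user_reply(config, detailed: str) -> list:
    # Mirrors handle_exception: hide details when hide_exce_info_to_user is set,
    # optionally replacing them with alter_tip_message (an empty tip means no reply at all).
    if getattr(config, 'hide_exce_info_to_user', False):
        if hasattr(config, 'alter_tip_message'):
            return [config.alter_tip_message] if config.alter_tip_message else []
        return ["[bot]出错了,请重试或联系管理员"]
    return [detailed]

print(choose_user_reply(SimpleNamespace(), "[bot]err:details"))
print(choose_user_reply(SimpleNamespace(hide_exce_info_to_user=True, alter_tip_message="出错了,稍后再试"), "[bot]err:details"))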


@@ -1,168 +0,0 @@
# 此模块提供了消息处理的具体逻辑的接口
import asyncio
import time
import mirai
import logging
from mirai import MessageChain, Plain
# 这里不使用动态引入config
# 因为在这里动态引入会卡死程序
# 而此模块静态引用config与动态引入的表现一致
# 已弃用,由于超时时间现已动态使用
# import config as config_init_import
import pkg.openai.session
import pkg.openai.manager
import pkg.utils.reloader
import pkg.utils.updater
import pkg.utils.context
import pkg.qqbot.message
import pkg.qqbot.command
import pkg.qqbot.ratelimit as ratelimit
import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models
import pkg.qqbot.ignore as ignore
import pkg.qqbot.banlist as banlist
processing = []
def is_admin(qq: int) -> bool:
"""兼容list和int类型的管理员判断"""
import config
if type(config.admin_qq) == list:
return qq in config.admin_qq
else:
return qq == config.admin_qq
def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: MessageChain,
sender_id: int) -> MessageChain:
global processing
mgr = pkg.utils.context.get_qqbot_manager()
reply = []
session_name = "{}_{}".format(launcher_type, launcher_id)
# 检查发送方是否被禁用
if banlist.is_banned(launcher_type, launcher_id, sender_id):
logging.info("根据禁用列表忽略{}_{}的消息".format(launcher_type, launcher_id))
return []
if ignore.ignore(text_message):
logging.info("根据忽略规则忽略消息: {}".format(text_message))
return []
# 检查是否被禁言
if launcher_type == 'group':
result = mgr.bot.member_info(target=launcher_id, member_id=mgr.bot.qq).get()
result = asyncio.run(result)
if result.mute_time_remaining > 0:
logging.info("机器人被禁言,跳过消息处理(group_{},剩余{}s)".format(launcher_id,
result.mute_time_remaining))
return reply
import config
if hasattr(config, 'income_msg_check') and config.income_msg_check:
if mgr.reply_filter.is_illegal(text_message):
return MessageChain(Plain("[bot] 你的提问中有不合适的内容, 请更换措辞~"))
pkg.openai.session.get_session(session_name).acquire_response_lock()
text_message = text_message.strip()
# 处理消息
try:
if session_name in processing:
pkg.openai.session.get_session(session_name).release_response_lock()
return MessageChain([Plain("[bot]err:正在处理中,请稍后再试")])
config = pkg.utils.context.get_config()
processing.append(session_name)
try:
if text_message.startswith('!') or text_message.startswith("!"): # 指令
# 触发插件事件
args = {
'launcher_type': launcher_type,
'launcher_id': launcher_id,
'sender_id': sender_id,
'command': text_message[1:].strip().split(' ')[0],
'params': text_message[1:].strip().split(' ')[1:],
'text_message': text_message,
'is_admin': is_admin(sender_id),
}
event = plugin_host.emit(plugin_models.PersonCommandSent
if launcher_type == 'person'
else plugin_models.GroupCommandSent, **args)
if event.get_return_value("alter") is not None:
text_message = event.get_return_value("alter")
# 取出插件提交的返回值赋值给reply
if event.get_return_value("reply") is not None:
reply = event.get_return_value("reply")
if not event.is_prevented_default():
reply = pkg.qqbot.command.process_command(session_name, text_message,
mgr, config, launcher_type, launcher_id, sender_id, is_admin(sender_id))
else: # 消息
# 限速丢弃检查
# print(ratelimit.__crt_minute_usage__[session_name])
if hasattr(config, "rate_limitation") and config.rate_limit_strategy == "drop":
if ratelimit.is_reach_limit(session_name):
logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message))
return MessageChain(["[bot]"+config.rate_limit_drop_tip]) if hasattr(config, "rate_limit_drop_tip") and config.rate_limit_drop_tip != "" else []
before = time.time()
# 触发插件事件
args = {
"launcher_type": launcher_type,
"launcher_id": launcher_id,
"sender_id": sender_id,
"text_message": text_message,
}
event = plugin_host.emit(plugin_models.PersonNormalMessageReceived
if launcher_type == 'person'
else plugin_models.GroupNormalMessageReceived, **args)
if event.get_return_value("alter") is not None:
text_message = event.get_return_value("alter")
# 取出插件提交的返回值赋值给reply
if event.get_return_value("reply") is not None:
reply = event.get_return_value("reply")
if not event.is_prevented_default():
reply = pkg.qqbot.message.process_normal_message(text_message,
mgr, config, launcher_type, launcher_id, sender_id)
# 限速等待时间
if hasattr(config, "rate_limitation") and config.rate_limit_strategy == "wait":
time.sleep(ratelimit.get_rest_wait_time(session_name, time.time() - before))
if hasattr(config, "rate_limitation"):
ratelimit.add_usage(session_name)
if reply is not None and len(reply) > 0 and (type(reply[0]) == str or type(reply[0]) == mirai.Plain):
if type(reply[0]) == mirai.Plain:
reply[0] = reply[0].text
logging.info(
"回复[{}]文字消息:{}".format(session_name,
reply[0][:min(100, len(reply[0]))] + (
"..." if len(reply[0]) > 100 else "")))
reply = [mgr.reply_filter.process(reply[0])]
else:
logging.info("回复[{}]消息".format(session_name))
finally:
processing.remove(session_name)
finally:
pkg.openai.session.get_session(session_name).release_response_lock()
return MessageChain(reply)


@@ -1,86 +0,0 @@
# Rate-limiting module
import time
import logging
import threading

__crt_minute_usage__ = {}
"""Number of chat rounds per session within the current minute."""

__timer_thr__: threading.Thread = None


def add_usage(session_name: str):
    """Increase the usage counter of a session."""
    global __crt_minute_usage__
    if session_name in __crt_minute_usage__:
        __crt_minute_usage__[session_name] += 1
    else:
        __crt_minute_usage__[session_name] = 1


def start_timer():
    """Start the reset timer."""
    global __timer_thr__
    __timer_thr__ = threading.Thread(target=run_timer, daemon=True)
    __timer_thr__.start()


def run_timer():
    """Timer loop: clear the per-session counters once per minute."""
    global __crt_minute_usage__
    global __timer_thr__

    # Wait until the next full minute
    time.sleep(60 - time.time() % 60)
    while True:
        if __timer_thr__ != threading.current_thread():
            break
        logging.debug("清空当前分钟的对话次数")
        __crt_minute_usage__ = {}
        time.sleep(60)


def get_usage(session_name: str) -> int:
    """Get the usage counter of a session."""
    global __crt_minute_usage__
    if session_name in __crt_minute_usage__:
        return __crt_minute_usage__[session_name]
    else:
        return 0


def get_rest_wait_time(session_name: str, spent: float) -> float:
    """Get the remaining wait time for this round of the session."""
    global __crt_minute_usage__
    import config

    if not hasattr(config, 'rate_limitation'):
        return 0

    min_seconds_per_round = 60.0 / config.rate_limitation
    if session_name in __crt_minute_usage__:
        return max(0, min_seconds_per_round - spent)
    else:
        return 0


def is_reach_limit(session_name: str) -> bool:
    """Check whether the session has hit the per-minute limit."""
    global __crt_minute_usage__
    import config

    if not hasattr(config, 'rate_limitation'):
        return False

    if session_name in __crt_minute_usage__:
        return __crt_minute_usage__[session_name] >= config.rate_limitation
    else:
        return False


start_timer()
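
A quick worked example of the pacing arithmetic in get_rest_wait_time, assuming rate_limitation = 6 (at most 6 rounds per minute):

rate_limitation = 6                               # hypothetical config value
min_seconds_per_round = 60.0 / rate_limitation    # 10.0 seconds between rounds

for spent in (2.5, 10.0, 14.0):
    wait = max(0, min_seconds_per_round - spent)
    print("request took {:>4}s -> extra wait {}s".format(spent, wait))
# A 2.5 s round sleeps 7.5 s more; anything that already took 10 s or longer is not delayed further.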