mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
chore: 删除已弃用的文件
This commit is contained in:
parent
b5924bb34f
commit
238c55a40e
|
@ -1,7 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from . import model as file_model
|
||||
from ..utils import context
|
||||
from .impls import pymodule, json as json_file
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ import logging
|
|||
import asyncio
|
||||
|
||||
from ..qqbot import manager as qqbot_mgr
|
||||
from ..openai import manager as openai_mgr
|
||||
from ..openai.session import sessionmgr as llm_session_mgr
|
||||
from ..openai.requester import modelmgr as llm_model_mgr
|
||||
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
|
||||
|
@ -21,8 +20,6 @@ from ..pipeline import stagemgr
|
|||
class Application:
|
||||
im_mgr: qqbot_mgr.QQBotManager = None
|
||||
|
||||
llm_mgr: openai_mgr.OpenAIInteract = None
|
||||
|
||||
cmd_mgr: cmdmgr.CommandManager = None
|
||||
|
||||
sess_mgr: llm_session_mgr.SessionManager = None
|
||||
|
|
|
@ -14,7 +14,6 @@ from . import controller
|
|||
from ..pipeline import stagemgr
|
||||
from ..audit import identifier
|
||||
from ..database import manager as db_mgr
|
||||
from ..openai import manager as llm_mgr
|
||||
from ..openai.session import sessionmgr as llm_session_mgr
|
||||
from ..openai.requester import modelmgr as llm_model_mgr
|
||||
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
|
||||
|
@ -107,9 +106,6 @@ async def make_app() -> app.Application:
|
|||
db_mgr_inst.initialize_database()
|
||||
ap.db_mgr = db_mgr_inst
|
||||
|
||||
llm_mgr_inst = llm_mgr.OpenAIInteract(ap)
|
||||
ap.llm_mgr = llm_mgr_inst
|
||||
|
||||
cmd_mgr_inst = cmdmgr.CommandManager(ap)
|
||||
await cmd_mgr_inst.initialize()
|
||||
ap.cmd_mgr = cmd_mgr_inst
|
||||
|
@ -130,7 +126,7 @@ async def make_app() -> app.Application:
|
|||
await llm_tool_mgr_inst.initialize()
|
||||
ap.tool_mgr = llm_tool_mgr_inst
|
||||
|
||||
im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap)
|
||||
im_mgr_inst = im_mgr.QQBotManager(ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
ap.im_mgr = im_mgr_inst
|
||||
|
||||
|
|
|
@ -1,134 +0,0 @@
|
|||
# 多情景预设值管理
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from ..utils import context
|
||||
|
||||
# __current__ = "default"
|
||||
# """当前默认使用的情景预设的名称
|
||||
|
||||
# 由管理员使用`!default <名称>`命令切换
|
||||
# """
|
||||
|
||||
# __prompts_from_files__ = {}
|
||||
# """从文件中读取的情景预设值"""
|
||||
|
||||
# __scenario_from_files__ = {}
|
||||
|
||||
|
||||
class ScenarioMode:
|
||||
"""情景预设模式抽象类"""
|
||||
|
||||
using_prompt_name = "default"
|
||||
"""新session创建时使用的prompt名称"""
|
||||
|
||||
prompts: dict[str, list] = {}
|
||||
|
||||
def __init__(self):
|
||||
logging.debug("prompts: {}".format(self.prompts))
|
||||
|
||||
def list(self) -> dict[str, list]:
|
||||
"""获取所有情景预设的名称及内容"""
|
||||
return self.prompts
|
||||
|
||||
def get_prompt(self, name: str) -> tuple[list, str]:
|
||||
"""获取指定情景预设的名称及内容"""
|
||||
for key in self.prompts:
|
||||
if key.startswith(name):
|
||||
return self.prompts[key], key
|
||||
raise Exception("没有找到情景预设: {}".format(name))
|
||||
|
||||
def set_using_name(self, name: str) -> str:
|
||||
"""设置默认情景预设"""
|
||||
for key in self.prompts:
|
||||
if key.startswith(name):
|
||||
self.using_prompt_name = key
|
||||
return key
|
||||
raise Exception("没有找到情景预设: {}".format(name))
|
||||
|
||||
def get_full_name(self, name: str) -> str:
|
||||
"""获取完整的情景预设名称"""
|
||||
for key in self.prompts:
|
||||
if key.startswith(name):
|
||||
return key
|
||||
raise Exception("没有找到情景预设: {}".format(name))
|
||||
|
||||
def get_using_name(self) -> str:
|
||||
"""获取默认情景预设"""
|
||||
return self.using_prompt_name
|
||||
|
||||
|
||||
class NormalScenarioMode(ScenarioMode):
|
||||
"""普通情景预设模式"""
|
||||
|
||||
def __init__(self):
|
||||
config = context.get_config_manager().data
|
||||
|
||||
# 加载config中的default_prompt值
|
||||
if type(config['default_prompt']) == str:
|
||||
self.using_prompt_name = "default"
|
||||
self.prompts = {"default": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": config['default_prompt']
|
||||
}
|
||||
]}
|
||||
|
||||
elif type(config['default_prompt']) == dict:
|
||||
for key in config['default_prompt']:
|
||||
self.prompts[key] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": config['default_prompt'][key]
|
||||
}
|
||||
]
|
||||
|
||||
# 从prompts/目录下的文件中载入
|
||||
# 遍历文件
|
||||
for file in os.listdir("prompts"):
|
||||
with open(os.path.join("prompts", file), encoding="utf-8") as f:
|
||||
self.prompts[file] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f.read()
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class FullScenarioMode(ScenarioMode):
|
||||
"""完整情景预设模式"""
|
||||
|
||||
def __init__(self):
|
||||
"""从json读取所有"""
|
||||
# 遍历scenario/目录下的所有文件,以文件名为键,文件内容中的prompt为值
|
||||
for file in os.listdir("scenario"):
|
||||
if file == "default-template.json":
|
||||
continue
|
||||
with open(os.path.join("scenario", file), encoding="utf-8") as f:
|
||||
self.prompts[file] = json.load(f)["prompt"]
|
||||
|
||||
super().__init__()
|
||||
|
||||
|
||||
scenario_mode_mapping = {}
|
||||
"""情景预设模式名称与对象的映射"""
|
||||
|
||||
|
||||
def register_all():
|
||||
"""注册所有情景预设模式,不使用装饰器,因为装饰器的方式不支持热重载"""
|
||||
global scenario_mode_mapping
|
||||
scenario_mode_mapping = {
|
||||
"normal": NormalScenarioMode(),
|
||||
"full_scenario": FullScenarioMode()
|
||||
}
|
||||
|
||||
|
||||
def mode_inst() -> ScenarioMode:
|
||||
"""获取指定名称的情景预设模式对象"""
|
||||
config = context.get_config_manager().data
|
||||
|
||||
if config['preset_mode'] == "default":
|
||||
config['preset_mode'] = "normal"
|
||||
|
||||
return scenario_mode_mapping[config['preset_mode']]
|
|
@ -1,46 +0,0 @@
|
|||
# 封装了function calling的一些支持函数
|
||||
import logging
|
||||
|
||||
from ..plugin import host
|
||||
|
||||
|
||||
class ContentFunctionNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_func_schema_list() -> list:
|
||||
"""从plugin包中的函数结构中获取并处理成受GPT支持的格式"""
|
||||
if not host.__enable_content_functions__:
|
||||
return []
|
||||
|
||||
schemas = []
|
||||
|
||||
for func in host.__callable_functions__:
|
||||
if func['enabled']:
|
||||
fun_cp = func.copy()
|
||||
|
||||
del fun_cp['enabled']
|
||||
|
||||
schemas.append(fun_cp)
|
||||
|
||||
return schemas
|
||||
|
||||
def get_func(name: str) -> callable:
|
||||
if name not in host.__function_inst_map__:
|
||||
raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name))
|
||||
|
||||
return host.__function_inst_map__[name]
|
||||
|
||||
def get_func_schema(name: str) -> dict:
|
||||
for func in host.__callable_functions__:
|
||||
if func['name'] == name:
|
||||
return func
|
||||
raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name))
|
||||
|
||||
def execute_function(name: str, kwargs: dict) -> any:
|
||||
"""执行函数调用"""
|
||||
|
||||
logging.debug("executing function: name='{}', kwargs={}".format(name, kwargs))
|
||||
|
||||
func = get_func(name)
|
||||
return func(**kwargs)
|
|
@ -1,103 +0,0 @@
|
|||
# 此模块提供了维护api-key的各种功能
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
|
||||
|
||||
class KeysManager:
|
||||
api_key = {}
|
||||
"""所有api-key"""
|
||||
|
||||
using_key = ""
|
||||
"""当前使用的api-key"""
|
||||
|
||||
alerted = []
|
||||
"""已提示过超额的key
|
||||
|
||||
记录在此以避免重复提示
|
||||
"""
|
||||
|
||||
exceeded = []
|
||||
"""已超额的key
|
||||
|
||||
供自动切换功能识别
|
||||
"""
|
||||
|
||||
def get_using_key(self):
|
||||
return self.using_key
|
||||
|
||||
def get_using_key_md5(self):
|
||||
return hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
|
||||
|
||||
def __init__(self, api_key):
|
||||
|
||||
assert type(api_key) == dict
|
||||
self.api_key = api_key
|
||||
# 从usage中删除未加载的api-key的记录
|
||||
# 不删了,也许会运行时添加曾经有记录的api-key
|
||||
|
||||
self.auto_switch()
|
||||
|
||||
def auto_switch(self) -> tuple[bool, str]:
|
||||
"""尝试切换api-key
|
||||
|
||||
Returns:
|
||||
是否切换成功, 切换后的api-key的别名
|
||||
"""
|
||||
|
||||
index = 0
|
||||
|
||||
for key_name in self.api_key:
|
||||
if self.api_key[key_name] == self.using_key:
|
||||
break
|
||||
|
||||
index += 1
|
||||
|
||||
# 从当前key开始向后轮询
|
||||
start_index = index
|
||||
index += 1
|
||||
if index >= len(self.api_key):
|
||||
index = 0
|
||||
|
||||
while index != start_index:
|
||||
|
||||
key_name = list(self.api_key.keys())[index]
|
||||
|
||||
if self.api_key[key_name] not in self.exceeded:
|
||||
self.using_key = self.api_key[key_name]
|
||||
|
||||
logging.debug("使用api-key:" + key_name)
|
||||
|
||||
# 触发插件事件
|
||||
args = {
|
||||
"key_name": key_name,
|
||||
"key_list": self.api_key.keys()
|
||||
}
|
||||
_ = plugin_host.emit(plugin_models.KeySwitched, **args)
|
||||
|
||||
return True, key_name
|
||||
|
||||
index += 1
|
||||
if index >= len(self.api_key):
|
||||
index = 0
|
||||
|
||||
self.using_key = list(self.api_key.values())[start_index]
|
||||
logging.debug("使用api-key:" + list(self.api_key.keys())[start_index])
|
||||
|
||||
return False, list(self.api_key.keys())[start_index]
|
||||
|
||||
def add(self, key_name, key):
|
||||
self.api_key[key_name] = key
|
||||
|
||||
def set_current_exceeded(self):
|
||||
"""设置当前使用的api-key使用量超限"""
|
||||
self.exceeded.append(self.using_key)
|
||||
|
||||
def get_key_name(self, api_key):
|
||||
"""根据api-key获取其别名"""
|
||||
for key_name in self.api_key:
|
||||
if self.api_key[key_name] == api_key:
|
||||
return key_name
|
||||
return ""
|
|
@ -1,108 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import openai
|
||||
from openai.types import images_response
|
||||
|
||||
from ..openai import keymgr
|
||||
from ..utils import context
|
||||
from ..audit import gatherer
|
||||
from ..openai import modelmgr
|
||||
from ..openai.api import model as api_model
|
||||
from ..core import app
|
||||
|
||||
|
||||
class OpenAIInteract:
|
||||
"""OpenAI 接口封装
|
||||
|
||||
将文字接口和图片接口封装供调用方使用
|
||||
"""
|
||||
|
||||
key_mgr: keymgr.KeysManager = None
|
||||
|
||||
audit_mgr: gatherer.DataGatherer = None
|
||||
|
||||
default_image_api_params = {
|
||||
"size": "256x256",
|
||||
}
|
||||
|
||||
client: openai.Client = None
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
|
||||
cfg= ap.cfg_mgr.data
|
||||
api_key = cfg['openai_config']['api_key']
|
||||
|
||||
self.key_mgr = keymgr.KeysManager(api_key)
|
||||
self.audit_mgr = gatherer.DataGatherer()
|
||||
|
||||
# 配置OpenAI proxy
|
||||
openai.proxies = None # 先重置,因为重载后可能需要清除proxy
|
||||
if "http_proxy" in cfg['openai_config'] and cfg['openai_config']["http_proxy"] is not None:
|
||||
openai.proxies = {
|
||||
"http": cfg['openai_config']["http_proxy"],
|
||||
"https": cfg['openai_config']["http_proxy"]
|
||||
}
|
||||
|
||||
# 配置openai api_base
|
||||
if "reverse_proxy" in cfg['openai_config'] and cfg['openai_config']["reverse_proxy"] is not None:
|
||||
logging.debug("设置反向代理: "+cfg['openai_config']['reverse_proxy'])
|
||||
openai.base_url = cfg['openai_config']["reverse_proxy"]
|
||||
|
||||
|
||||
self.client = openai.Client(
|
||||
api_key=self.key_mgr.get_using_key(),
|
||||
base_url=openai.base_url
|
||||
)
|
||||
|
||||
context.set_openai_manager(self)
|
||||
|
||||
def request_completion(self, messages: list):
|
||||
"""请求补全接口回复=
|
||||
"""
|
||||
# 选择接口请求类
|
||||
config = context.get_config_manager().data
|
||||
|
||||
request: api_model.RequestBase
|
||||
|
||||
model: str = config['completion_api_params']['model']
|
||||
|
||||
cp_parmas = config['completion_api_params'].copy()
|
||||
del cp_parmas['model']
|
||||
|
||||
request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas)
|
||||
|
||||
# 请求接口
|
||||
for resp in request:
|
||||
|
||||
if resp['usage']['total_tokens'] > 0:
|
||||
self.audit_mgr.report_text_model_usage(
|
||||
model,
|
||||
resp['usage']['total_tokens']
|
||||
)
|
||||
|
||||
yield resp
|
||||
|
||||
def request_image(self, prompt) -> images_response.ImagesResponse:
|
||||
"""请求图片接口回复
|
||||
|
||||
Parameters:
|
||||
prompt (str): 提示语
|
||||
|
||||
Returns:
|
||||
dict: 响应
|
||||
"""
|
||||
config = context.get_config_manager().data
|
||||
params = config['image_api_params']
|
||||
|
||||
response = self.client.images.generate(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
**params
|
||||
)
|
||||
|
||||
self.audit_mgr.report_image_model_usage(params['size'])
|
||||
|
||||
return response
|
||||
|
|
@ -1,504 +0,0 @@
|
|||
"""主线使用的会话管理模块
|
||||
|
||||
每个人、每个群单独一个session,session内部保留了对话的上下文,
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import json
|
||||
|
||||
from ..openai import manager as openai_manager
|
||||
from ..openai import modelmgr as openai_modelmgr
|
||||
from ..database import manager as database_manager
|
||||
from ..utils import context as context
|
||||
|
||||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
|
||||
# 运行时保存的所有session
|
||||
sessions = {}
|
||||
|
||||
|
||||
class SessionOfflineStatus:
|
||||
ON_GOING = 'on_going'
|
||||
EXPLICITLY_CLOSED = 'explicitly_closed'
|
||||
|
||||
|
||||
# 从数据加载session
|
||||
def load_sessions():
|
||||
"""从数据库加载sessions"""
|
||||
|
||||
global sessions
|
||||
|
||||
db_inst = context.get_database_manager()
|
||||
|
||||
session_data = db_inst.load_valid_sessions()
|
||||
|
||||
for session_name in session_data:
|
||||
logging.debug('加载session: {}'.format(session_name))
|
||||
|
||||
temp_session = Session(session_name)
|
||||
temp_session.name = session_name
|
||||
temp_session.create_timestamp = session_data[session_name]['create_timestamp']
|
||||
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
|
||||
|
||||
temp_session.prompt = json.loads(session_data[session_name]['prompt'])
|
||||
temp_session.token_counts = json.loads(session_data[session_name]['token_counts'])
|
||||
|
||||
temp_session.default_prompt = json.loads(session_data[session_name]['default_prompt']) if \
|
||||
session_data[session_name]['default_prompt'] else []
|
||||
|
||||
sessions[session_name] = temp_session
|
||||
|
||||
|
||||
# 获取指定名称的session,如果不存在则创建一个新的
|
||||
def get_session(session_name: str) -> 'Session':
|
||||
global sessions
|
||||
if session_name not in sessions:
|
||||
sessions[session_name] = Session(session_name)
|
||||
return sessions[session_name]
|
||||
|
||||
|
||||
def dump_session(session_name: str):
|
||||
global sessions
|
||||
if session_name in sessions:
|
||||
assert isinstance(sessions[session_name], Session)
|
||||
sessions[session_name].persistence()
|
||||
del sessions[session_name]
|
||||
|
||||
|
||||
# 通用的OpenAI API交互session
|
||||
# session内部保留了对话的上下文,
|
||||
# 收到用户消息后,将上下文提交给OpenAI API生成回复
|
||||
class Session:
|
||||
name = ''
|
||||
|
||||
prompt = []
|
||||
"""使用list来保存会话中的回合"""
|
||||
|
||||
default_prompt = []
|
||||
"""本session的默认prompt"""
|
||||
|
||||
create_timestamp = 0
|
||||
"""会话创建时间"""
|
||||
|
||||
last_interact_timestamp = 0
|
||||
"""上次交互(产生回复)时间"""
|
||||
|
||||
just_switched_to_exist_session = False
|
||||
|
||||
response_lock = None
|
||||
|
||||
# 加锁
|
||||
def acquire_response_lock(self):
|
||||
logging.debug('{},lock acquire,{}'.format(self.name, self.response_lock))
|
||||
self.response_lock.acquire()
|
||||
logging.debug('{},lock acquire successfully,{}'.format(self.name, self.response_lock))
|
||||
|
||||
# 释放锁
|
||||
def release_response_lock(self):
|
||||
if self.response_lock.locked():
|
||||
logging.debug('{},lock release,{}'.format(self.name, self.response_lock))
|
||||
self.response_lock.release()
|
||||
logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock))
|
||||
|
||||
# 从配置文件获取会话预设信息
|
||||
def get_default_prompt(self, use_default: str = None):
|
||||
import pkg.openai.dprompt as dprompt
|
||||
|
||||
if use_default is None:
|
||||
use_default = dprompt.mode_inst().get_using_name()
|
||||
|
||||
current_default_prompt, _ = dprompt.mode_inst().get_prompt(use_default)
|
||||
return current_default_prompt
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.create_timestamp = int(time.time())
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
self.prompt = []
|
||||
self.token_counts = []
|
||||
self.schedule()
|
||||
|
||||
self.response_lock = threading.Lock()
|
||||
|
||||
self.default_prompt = self.get_default_prompt()
|
||||
logging.debug("prompt is: {}".format(self.default_prompt))
|
||||
|
||||
# 设定检查session最后一次对话是否超过过期时间的计时器
|
||||
def schedule(self):
|
||||
threading.Thread(target=self.expire_check_timer_loop, args=(self.create_timestamp,)).start()
|
||||
|
||||
# 检查session是否已经过期
|
||||
def expire_check_timer_loop(self, create_timestamp: int):
|
||||
global sessions
|
||||
while True:
|
||||
time.sleep(60)
|
||||
|
||||
# 不是此session已更换,退出
|
||||
if self.create_timestamp != create_timestamp or self not in sessions.values():
|
||||
return
|
||||
|
||||
config = context.get_config_manager().data
|
||||
if int(time.time()) - self.last_interact_timestamp > config['session_expire_time']:
|
||||
logging.info('session {} 已过期'.format(self.name))
|
||||
|
||||
# 触发插件事件
|
||||
args = {
|
||||
'session_name': self.name,
|
||||
'session': self,
|
||||
'session_expire_time': config['session_expire_time']
|
||||
}
|
||||
event = plugin_host.emit(plugin_models.SessionExpired, **args)
|
||||
if event.is_prevented_default():
|
||||
return
|
||||
|
||||
self.reset(expired=True, schedule_new=False)
|
||||
|
||||
# 删除此session
|
||||
del sessions[self.name]
|
||||
return
|
||||
|
||||
# 请求回复
|
||||
# 这个函数是阻塞的
|
||||
def query(self, text: str=None) -> tuple[str, str, list[str]]:
|
||||
"""向session中添加一条消息,返回接口回复
|
||||
|
||||
Args:
|
||||
text (str): 用户消息
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: (接口回复, finish_reason, 已调用的函数列表)
|
||||
"""
|
||||
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
|
||||
# 触发插件事件
|
||||
if not self.prompt:
|
||||
args = {
|
||||
'session_name': self.name,
|
||||
'session': self,
|
||||
'default_prompt': self.default_prompt,
|
||||
}
|
||||
|
||||
event = plugin_host.emit(plugin_models.SessionFirstMessageReceived, **args)
|
||||
if event.is_prevented_default():
|
||||
return None, None, None
|
||||
|
||||
config = context.get_config_manager().data
|
||||
max_length = config['prompt_submit_length']
|
||||
|
||||
local_default_prompt = self.default_prompt.copy()
|
||||
local_prompt = self.prompt.copy()
|
||||
|
||||
# 触发PromptPreProcessing事件
|
||||
args = {
|
||||
'session_name': self.name,
|
||||
'default_prompt': self.default_prompt,
|
||||
'prompt': self.prompt,
|
||||
'text_message': text,
|
||||
}
|
||||
|
||||
event = plugin_host.emit(plugin_models.PromptPreProcessing, **args)
|
||||
|
||||
if event.get_return_value('default_prompt') is not None:
|
||||
local_default_prompt = event.get_return_value('default_prompt')
|
||||
|
||||
if event.get_return_value('prompt') is not None:
|
||||
local_prompt = event.get_return_value('prompt')
|
||||
|
||||
if event.get_return_value('text_message') is not None:
|
||||
text = event.get_return_value('text_message')
|
||||
|
||||
# 裁剪messages到合适长度
|
||||
prompts, _ = self.cut_out(text, max_length, local_default_prompt, local_prompt)
|
||||
|
||||
res_text = ""
|
||||
|
||||
pending_msgs = []
|
||||
|
||||
total_tokens = 0
|
||||
|
||||
finish_reason: str = ""
|
||||
|
||||
funcs = []
|
||||
|
||||
trace_func_calls = config['trace_function_calls']
|
||||
botmgr = context.get_qqbot_manager()
|
||||
|
||||
session_name_spt: list[str] = self.name.split("_")
|
||||
|
||||
pending_res_text = ""
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# TODO 对不起,我知道这样非常非常屎山,但我之后会重构的
|
||||
for resp in context.get_openai_manager().request_completion(prompts):
|
||||
|
||||
if pending_res_text != "":
|
||||
botmgr.adapter.send_message(
|
||||
session_name_spt[0],
|
||||
session_name_spt[1],
|
||||
pending_res_text
|
||||
)
|
||||
pending_res_text = ""
|
||||
|
||||
finish_reason = resp['choices'][0]['finish_reason']
|
||||
|
||||
if resp['choices'][0]['message']['role'] == "assistant" and resp['choices'][0]['message']['content'] != None: # 包含纯文本响应
|
||||
|
||||
if not trace_func_calls:
|
||||
res_text += resp['choices'][0]['message']['content']
|
||||
else:
|
||||
res_text = resp['choices'][0]['message']['content']
|
||||
pending_res_text = resp['choices'][0]['message']['content']
|
||||
|
||||
total_tokens += resp['usage']['total_tokens']
|
||||
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": resp['choices'][0]['message']['content']
|
||||
}
|
||||
|
||||
if 'function_call' in resp['choices'][0]['message']:
|
||||
msg['function_call'] = json.dumps(resp['choices'][0]['message']['function_call'])
|
||||
|
||||
pending_msgs.append(msg)
|
||||
|
||||
if resp['choices'][0]['message']['type'] == 'function_call':
|
||||
# self.prompt.append(
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "function call: "+json.dumps(resp['choices'][0]['message']['function_call'])
|
||||
# }
|
||||
# )
|
||||
if trace_func_calls:
|
||||
botmgr.adapter.send_message(
|
||||
session_name_spt[0],
|
||||
session_name_spt[1],
|
||||
"调用函数 "+resp['choices'][0]['message']['function_call']['name'] + "..."
|
||||
)
|
||||
|
||||
total_tokens += resp['usage']['total_tokens']
|
||||
elif resp['choices'][0]['message']['type'] == 'function_return':
|
||||
# self.prompt.append(
|
||||
# {
|
||||
# "role": "function",
|
||||
# "name": resp['choices'][0]['message']['function_name'],
|
||||
# "content": json.dumps(resp['choices'][0]['message']['content'])
|
||||
# }
|
||||
# )
|
||||
|
||||
# total_tokens += resp['usage']['total_tokens']
|
||||
funcs.append(
|
||||
resp['choices'][0]['message']['function_name']
|
||||
)
|
||||
pass
|
||||
|
||||
# 向API请求补全
|
||||
# message, total_token = pkg.utils.context.get_openai_manager().request_completion(
|
||||
# prompts,
|
||||
# )
|
||||
|
||||
# 成功获取,处理回复
|
||||
# res_test = message
|
||||
res_ans = res_text.strip()
|
||||
|
||||
# 将此次对话的双方内容加入到prompt中
|
||||
# self.prompt.append({'role': 'user', 'content': text})
|
||||
# self.prompt.append({'role': 'assistant', 'content': res_ans})
|
||||
if text:
|
||||
self.prompt.append({'role': 'user', 'content': text})
|
||||
# 添加pending_msgs
|
||||
self.prompt += pending_msgs
|
||||
|
||||
# 向token_counts中添加本回合的token数量
|
||||
# self.token_counts.append(total_tokens-total_token_before_query)
|
||||
# logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts))
|
||||
|
||||
if self.just_switched_to_exist_session:
|
||||
self.just_switched_to_exist_session = False
|
||||
self.set_ongoing()
|
||||
|
||||
# 上报使用量数据
|
||||
session_type = session_name_spt[0]
|
||||
session_id = session_name_spt[1]
|
||||
|
||||
ability_provider = "QChatGPT.Text"
|
||||
usage = total_tokens
|
||||
model_name = context.get_config_manager().data['completion_api_params']['model']
|
||||
response_seconds = int(time.time() - start_time)
|
||||
retry_times = -1 # 暂不记录
|
||||
|
||||
context.get_center_v2_api().usage.post_query_record(
|
||||
session_type=session_type,
|
||||
session_id=session_id,
|
||||
query_ability_provider=ability_provider,
|
||||
usage=usage,
|
||||
model_name=model_name,
|
||||
response_seconds=response_seconds,
|
||||
retry_times=retry_times
|
||||
)
|
||||
|
||||
return res_ans if res_ans[0] != '\n' else res_ans[1:], finish_reason, funcs
|
||||
|
||||
# 删除上一回合并返回上一回合的问题
|
||||
def undo(self) -> str:
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
|
||||
# 删除最后两个消息
|
||||
if len(self.prompt) < 2:
|
||||
raise Exception('之前无对话,无法撤销')
|
||||
|
||||
question = self.prompt[-2]['content']
|
||||
self.prompt = self.prompt[:-2]
|
||||
self.token_counts = self.token_counts[:-1]
|
||||
|
||||
# 返回上一回合的问题
|
||||
return question
|
||||
|
||||
# 构建对话体
|
||||
def cut_out(self, msg: str, max_tokens: int, default_prompt: list, prompt: list) -> tuple[list, list]:
|
||||
"""将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens
|
||||
|
||||
:return: (新的prompt, 新的token_counts)
|
||||
"""
|
||||
|
||||
# 最终由三个部分组成
|
||||
# - default_prompt 情景预设固定值
|
||||
# - changable_prompts 可变部分, 此会话中的历史对话回合
|
||||
# - current_question 当前问题
|
||||
|
||||
# 包装目前的对话回合内容
|
||||
changable_prompts = []
|
||||
|
||||
use_model = context.get_config_manager().data['completion_api_params']['model']
|
||||
|
||||
ptr = len(prompt) - 1
|
||||
|
||||
# 直接从后向前扫描拼接,不管是否是整回合
|
||||
while ptr >= 0:
|
||||
if openai_modelmgr.count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
|
||||
break
|
||||
|
||||
changable_prompts.insert(0, prompt[ptr])
|
||||
|
||||
ptr -= 1
|
||||
|
||||
# 将default_prompt和changable_prompts合并
|
||||
result_prompt = default_prompt + changable_prompts
|
||||
|
||||
# 添加当前问题
|
||||
if msg:
|
||||
result_prompt.append(
|
||||
{
|
||||
'role': 'user',
|
||||
'content': msg
|
||||
}
|
||||
)
|
||||
|
||||
logging.debug("cut_out: {}".format(json.dumps(result_prompt, ensure_ascii=False, indent=4)))
|
||||
|
||||
return result_prompt, openai_modelmgr.count_tokens(changable_prompts, use_model)
|
||||
|
||||
# 持久化session
|
||||
def persistence(self):
|
||||
if self.prompt == self.get_default_prompt():
|
||||
return
|
||||
|
||||
db_inst = context.get_database_manager()
|
||||
|
||||
name_spt = self.name.split('_')
|
||||
|
||||
subject_type = name_spt[0]
|
||||
subject_number = int(name_spt[1])
|
||||
|
||||
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
|
||||
json.dumps(self.prompt), json.dumps(self.default_prompt), json.dumps(self.token_counts))
|
||||
|
||||
# 重置session
|
||||
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None, persist: bool = False):
|
||||
if self.prompt:
|
||||
self.persistence()
|
||||
if explicit:
|
||||
# 触发插件事件
|
||||
args = {
|
||||
'session_name': self.name,
|
||||
'session': self
|
||||
}
|
||||
|
||||
# 此事件不支持阻止默认行为
|
||||
_ = plugin_host.emit(plugin_models.SessionExplicitReset, **args)
|
||||
|
||||
context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
|
||||
|
||||
if expired:
|
||||
context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
|
||||
|
||||
if not persist: # 不要求保持default prompt
|
||||
self.default_prompt = self.get_default_prompt(use_prompt)
|
||||
self.prompt = []
|
||||
self.token_counts = []
|
||||
self.create_timestamp = int(time.time())
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
self.just_switched_to_exist_session = False
|
||||
|
||||
# self.response_lock = threading.Lock()
|
||||
|
||||
if schedule_new:
|
||||
self.schedule()
|
||||
|
||||
# 将本session的数据库状态设置为on_going
|
||||
def set_ongoing(self):
|
||||
context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)
|
||||
|
||||
# 切换到上一个session
|
||||
def last_session(self):
|
||||
last_one = context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
|
||||
if last_one is None:
|
||||
return None
|
||||
else:
|
||||
self.persistence()
|
||||
|
||||
self.create_timestamp = last_one['create_timestamp']
|
||||
self.last_interact_timestamp = last_one['last_interact_timestamp']
|
||||
|
||||
self.prompt = json.loads(last_one['prompt'])
|
||||
self.token_counts = json.loads(last_one['token_counts'])
|
||||
|
||||
self.default_prompt = json.loads(last_one['default_prompt']) if last_one['default_prompt'] else []
|
||||
|
||||
self.just_switched_to_exist_session = True
|
||||
return self
|
||||
|
||||
# 切换到下一个session
|
||||
def next_session(self):
|
||||
next_one = context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
|
||||
if next_one is None:
|
||||
return None
|
||||
else:
|
||||
self.persistence()
|
||||
|
||||
self.create_timestamp = next_one['create_timestamp']
|
||||
self.last_interact_timestamp = next_one['last_interact_timestamp']
|
||||
|
||||
self.prompt = json.loads(next_one['prompt'])
|
||||
self.token_counts = json.loads(next_one['token_counts'])
|
||||
|
||||
self.default_prompt = json.loads(next_one['default_prompt']) if next_one['default_prompt'] else []
|
||||
|
||||
self.just_switched_to_exist_session = True
|
||||
return self
|
||||
|
||||
def list_history(self, capacity: int = 10, page: int = 0):
|
||||
return context.get_database_manager().list_history(self.name, capacity, page)
|
||||
|
||||
def delete_history(self, index: int) -> bool:
|
||||
return context.get_database_manager().delete_history(self.name, index)
|
||||
|
||||
def delete_all_history(self) -> bool:
|
||||
return context.get_database_manager().delete_all_history(self.name)
|
||||
|
||||
def draw_image(self, prompt: str):
|
||||
return context.get_openai_manager().request_image(prompt)
|
|
@ -1,333 +0,0 @@
|
|||
import logging
|
||||
import copy
|
||||
import pkgutil
|
||||
import traceback
|
||||
import json
|
||||
|
||||
import tips as tips_custom
|
||||
|
||||
|
||||
__command_list__ = {}
|
||||
"""命令树
|
||||
|
||||
结构:
|
||||
{
|
||||
'cmd1': {
|
||||
'description': 'cmd1 description',
|
||||
'usage': 'cmd1 usage',
|
||||
'aliases': ['cmd1 alias1', 'cmd1 alias2'],
|
||||
'privilege': 0,
|
||||
'parent': None,
|
||||
'cls': <class 'pkg.qqbot.cmds.cmd1.CommandCmd1'>,
|
||||
'sub': [
|
||||
'cmd1-1'
|
||||
]
|
||||
},
|
||||
'cmd1.cmd1-1: {
|
||||
'description': 'cmd1-1 description',
|
||||
'usage': 'cmd1-1 usage',
|
||||
'aliases': ['cmd1-1 alias1', 'cmd1-1 alias2'],
|
||||
'privilege': 0,
|
||||
'parent': 'cmd1',
|
||||
'cls': <class 'pkg.qqbot.cmds.cmd1.CommandCmd1_1'>,
|
||||
'sub': []
|
||||
},
|
||||
'cmd2': {
|
||||
'description': 'cmd2 description',
|
||||
'usage': 'cmd2 usage',
|
||||
'aliases': ['cmd2 alias1', 'cmd2 alias2'],
|
||||
'privilege': 0,
|
||||
'parent': None,
|
||||
'cls': <class 'pkg.qqbot.cmds.cmd2.CommandCmd2'>,
|
||||
'sub': [
|
||||
'cmd2-1'
|
||||
]
|
||||
},
|
||||
'cmd2.cmd2-1': {
|
||||
'description': 'cmd2-1 description',
|
||||
'usage': 'cmd2-1 usage',
|
||||
'aliases': ['cmd2-1 alias1', 'cmd2-1 alias2'],
|
||||
'privilege': 0,
|
||||
'parent': 'cmd2',
|
||||
'cls': <class 'pkg.qqbot.cmds.cmd2.CommandCmd2_1'>,
|
||||
'sub': [
|
||||
'cmd2-1-1'
|
||||
]
|
||||
},
|
||||
'cmd2.cmd2-1.cmd2-1-1': {
|
||||
'description': 'cmd2-1-1 description',
|
||||
'usage': 'cmd2-1-1 usage',
|
||||
'aliases': ['cmd2-1-1 alias1', 'cmd2-1-1 alias2'],
|
||||
'privilege': 0,
|
||||
'parent': 'cmd2.cmd2-1',
|
||||
'cls': <class 'pkg.qqbot.cmds.cmd2.CommandCmd2_1_1'>,
|
||||
'sub': []
|
||||
},
|
||||
}
|
||||
"""
|
||||
|
||||
__tree_index__: dict[str, list] = {}
|
||||
"""命令树索引
|
||||
|
||||
结构:
|
||||
{
|
||||
'pkg.qqbot.cmds.cmd1.CommandCmd1': 'cmd1', # 顶级命令
|
||||
'pkg.qqbot.cmds.cmd1.CommandCmd1_1': 'cmd1.cmd1-1', # 类名: 节点路径
|
||||
'pkg.qqbot.cmds.cmd2.CommandCmd2': 'cmd2',
|
||||
'pkg.qqbot.cmds.cmd2.CommandCmd2_1': 'cmd2.cmd2-1',
|
||||
'pkg.qqbot.cmds.cmd2.CommandCmd2_1_1': 'cmd2.cmd2-1.cmd2-1-1',
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class Context:
|
||||
"""命令执行上下文"""
|
||||
command: str
|
||||
"""顶级命令文本"""
|
||||
|
||||
crt_command: str
|
||||
"""当前子命令文本"""
|
||||
|
||||
params: list
|
||||
"""完整参数列表"""
|
||||
|
||||
crt_params: list
|
||||
"""当前子命令参数列表"""
|
||||
|
||||
session_name: str
|
||||
"""会话名"""
|
||||
|
||||
text_message: str
|
||||
"""命令完整文本"""
|
||||
|
||||
launcher_type: str
|
||||
"""命令发起者类型"""
|
||||
|
||||
launcher_id: int
|
||||
"""命令发起者ID"""
|
||||
|
||||
sender_id: int
|
||||
"""命令发送者ID"""
|
||||
|
||||
is_admin: bool
|
||||
"""[过时]命令发送者是否为管理员"""
|
||||
|
||||
privilege: int
|
||||
"""命令发送者权限等级"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
|
||||
class AbstractCommandNode:
|
||||
"""命令抽象类"""
|
||||
|
||||
parent: type
|
||||
"""父命令类"""
|
||||
|
||||
name: str
|
||||
"""命令名"""
|
||||
|
||||
description: str
|
||||
"""命令描述"""
|
||||
|
||||
usage: str
|
||||
"""命令用法"""
|
||||
|
||||
aliases: list[str]
|
||||
"""命令别名"""
|
||||
|
||||
privilege: int
|
||||
"""命令权限等级, 权限大于等于此值的用户才能执行命令"""
|
||||
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
"""命令处理函数
|
||||
|
||||
:param ctx: 命令执行上下文
|
||||
|
||||
:return: (是否执行, 回复列表(若执行))
|
||||
|
||||
若未执行,将自动以下一个参数查找并执行子命令
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def help(cls) -> str:
|
||||
"""获取命令帮助信息"""
|
||||
return '命令: {}\n描述: {}\n用法: \n{}\n别名: {}\n权限: {}'.format(
|
||||
cls.name,
|
||||
cls.description,
|
||||
cls.usage,
|
||||
', '.join(cls.aliases),
|
||||
cls.privilege
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def register(
|
||||
parent: type = None,
|
||||
name: str = None,
|
||||
description: str = None,
|
||||
usage: str = None,
|
||||
aliases: list[str] = None,
|
||||
privilege: int = 0
|
||||
):
|
||||
"""注册命令
|
||||
|
||||
:param cls: 命令类
|
||||
:param name: 命令名
|
||||
:param parent: 父命令类
|
||||
"""
|
||||
global __command_list__, __tree_index__
|
||||
|
||||
def wrapper(cls):
|
||||
cls.name = name
|
||||
cls.parent = parent
|
||||
cls.description = description
|
||||
cls.usage = usage
|
||||
cls.aliases = aliases
|
||||
cls.privilege = privilege
|
||||
|
||||
logging.debug("cls: {}, name: {}, parent: {}".format(cls, name, parent))
|
||||
|
||||
if parent is None:
|
||||
# 顶级命令注册
|
||||
__command_list__[name] = {
|
||||
'description': cls.description,
|
||||
'usage': cls.usage,
|
||||
'aliases': cls.aliases,
|
||||
'privilege': cls.privilege,
|
||||
'parent': None,
|
||||
'cls': cls,
|
||||
'sub': []
|
||||
}
|
||||
# 更新索引
|
||||
__tree_index__[cls.__module__ + '.' + cls.__name__] = name
|
||||
else:
|
||||
# 获取父节点名称
|
||||
path = __tree_index__[parent.__module__ + '.' + parent.__name__]
|
||||
|
||||
parent_node = __command_list__[path]
|
||||
# 链接父子命令
|
||||
__command_list__[path]['sub'].append(name)
|
||||
# 注册子命令
|
||||
__command_list__[path + '.' + name] = {
|
||||
'description': cls.description,
|
||||
'usage': cls.usage,
|
||||
'aliases': cls.aliases,
|
||||
'privilege': cls.privilege,
|
||||
'parent': path,
|
||||
'cls': cls,
|
||||
'sub': []
|
||||
}
|
||||
# 更新索引
|
||||
__tree_index__[cls.__module__ + '.' + cls.__name__] = path + '.' + name
|
||||
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class CommandPrivilegeError(Exception):
|
||||
"""命令权限不足或不存在异常"""
|
||||
pass
|
||||
|
||||
|
||||
# 传入Context对象,广搜命令树,返回执行结果
|
||||
# 若命令被处理,返回reply列表
|
||||
# 若命令未被处理,继续执行下一级命令
|
||||
# 若命令不存在,报异常
|
||||
def execute(context: Context) -> list:
|
||||
"""执行命令
|
||||
|
||||
:param ctx: 命令执行上下文
|
||||
|
||||
:return: 回复列表
|
||||
"""
|
||||
global __command_list__
|
||||
|
||||
# 拷贝ctx
|
||||
ctx: Context = copy.deepcopy(context)
|
||||
|
||||
# 从树取出顶级命令
|
||||
node = __command_list__
|
||||
|
||||
path = ctx.command
|
||||
|
||||
while True:
|
||||
try:
|
||||
node = __command_list__[path]
|
||||
logging.debug('执行命令: {}'.format(path))
|
||||
|
||||
# 检查权限
|
||||
if ctx.privilege < node['privilege']:
|
||||
raise CommandPrivilegeError(tips_custom.command_admin_message+"{}".format(path))
|
||||
|
||||
# 执行
|
||||
execed, reply = node['cls'].process(ctx)
|
||||
if execed:
|
||||
return reply
|
||||
else:
|
||||
# 删除crt_params第一个参数
|
||||
ctx.crt_command = ctx.crt_params.pop(0)
|
||||
# 下一个path
|
||||
path = path + '.' + ctx.crt_command
|
||||
except KeyError:
|
||||
traceback.print_exc()
|
||||
raise CommandPrivilegeError(tips_custom.command_err_message+"{}".format(path))
|
||||
|
||||
|
||||
def register_all():
|
||||
"""启动时调用此函数注册所有命令
|
||||
|
||||
递归处理pkg.qqbot.cmds包下及其子包下所有模块的所有继承于AbstractCommand的类
|
||||
"""
|
||||
# 模块:遍历其中的继承于AbstractCommand的类,进行注册
|
||||
# 包:递归处理包下的模块
|
||||
# 排除__开头的属性
|
||||
global __command_list__, __tree_index__
|
||||
|
||||
import pkg.qqbot.cmds
|
||||
|
||||
def walk(module, prefix, path_prefix):
|
||||
# 排除不处于pkg.qqbot.cmds中的包
|
||||
if not module.__name__.startswith('pkg.qqbot.cmds'):
|
||||
return
|
||||
|
||||
logging.debug('walk: {}, path: {}'.format(module.__name__, module.__path__))
|
||||
for item in pkgutil.iter_modules(module.__path__):
|
||||
if item.name.startswith('__'):
|
||||
continue
|
||||
|
||||
if item.ispkg:
|
||||
walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.', path_prefix + item.name + '/')
|
||||
else:
|
||||
m = __import__(module.__name__ + '.' + item.name, fromlist=[''])
|
||||
# for name, cls in inspect.getmembers(m, inspect.isclass):
|
||||
# # 检查是否为命令类
|
||||
# if cls.__module__ == m.__name__ and issubclass(cls, AbstractCommandNode) and cls != AbstractCommandNode:
|
||||
# cls.register(cls, cls.name, cls.parent)
|
||||
|
||||
walk(pkg.qqbot.cmds, '', '')
|
||||
logging.debug(__command_list__)
|
||||
|
||||
|
||||
def apply_privileges():
|
||||
"""读取cmdpriv.json并应用命令权限"""
|
||||
# 读取内容
|
||||
json_str = ""
|
||||
with open('cmdpriv.json', 'r', encoding="utf-8") as f:
|
||||
json_str = f.read()
|
||||
|
||||
data = json.loads(json_str)
|
||||
for path, priv in data.items():
|
||||
if path == 'comment':
|
||||
continue
|
||||
|
||||
if path not in __command_list__:
|
||||
continue
|
||||
|
||||
if __command_list__[path]['privilege'] != priv:
|
||||
logging.debug('应用权限: {} -> {}(default: {})'.format(path, priv, __command_list__[path]['privilege']))
|
||||
|
||||
__command_list__[path]['privilege'] = priv
|
|
@ -1,37 +0,0 @@
|
|||
import logging
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import aamgr
|
||||
from ....utils import context
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="draw",
|
||||
description="使用DALL·E生成图片",
|
||||
usage="!draw <图片提示语>",
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class DrawCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
|
||||
reply = []
|
||||
|
||||
if len(ctx.params) == 0:
|
||||
reply = ["[bot]err: 未提供图片描述文字"]
|
||||
else:
|
||||
session = pkg.openai.session.get_session(ctx.session_name)
|
||||
|
||||
res = session.draw_image(" ".join(ctx.params))
|
||||
|
||||
logging.debug("draw_image result:{}".format(res))
|
||||
reply = [mirai.Image(url=res.data[0].url)]
|
||||
config = context.get_config_manager().data
|
||||
if config['include_image_description']:
|
||||
reply.append(" ".join(ctx.params))
|
||||
|
||||
return True, reply
|
|
@ -1,32 +0,0 @@
|
|||
import logging
|
||||
import json
|
||||
|
||||
from .. import aamgr
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="func",
|
||||
description="管理内容函数",
|
||||
usage="!func",
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class FuncCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
from pkg.plugin.models import host
|
||||
|
||||
reply = []
|
||||
|
||||
reply_str = "当前已加载的内容函数:\n\n"
|
||||
|
||||
logging.debug("host.__callable_functions__: {}".format(json.dumps(host.__callable_functions__, indent=4)))
|
||||
|
||||
index = 1
|
||||
for func in host.__callable_functions__:
|
||||
reply_str += "{}. {}{}:\n{}\n\n".format(index, ("(已禁用) " if not func['enabled'] else ""), func['name'], func['description'])
|
||||
index += 1
|
||||
|
||||
reply = [reply_str]
|
||||
|
||||
return True, reply
|
|
@ -1,198 +0,0 @@
|
|||
from ....plugin import host as plugin_host
|
||||
from ....utils import updater
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="plugin",
|
||||
description="插件管理",
|
||||
usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>",
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class PluginCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
reply = []
|
||||
plugin_list = plugin_host.__plugins__
|
||||
if len(ctx.params) == 0:
|
||||
# 列出所有插件
|
||||
|
||||
reply_str = "[bot]所有插件({}):\n".format(len(plugin_host.__plugins__))
|
||||
idx = 0
|
||||
for key in plugin_host.iter_plugins_name():
|
||||
plugin = plugin_list[key]
|
||||
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\
|
||||
.format((idx+1), plugin['name'],
|
||||
"[已禁用]" if not plugin['enabled'] else "",
|
||||
plugin['description'],
|
||||
plugin['version'], plugin['author'])
|
||||
|
||||
if updater.is_repo("/".join(plugin['path'].split('/')[:-1])):
|
||||
remote_url = updater.get_remote_url("/".join(plugin['path'].split('/')[:-1]))
|
||||
if remote_url != "https://github.com/RockChinQ/QChatGPT" and remote_url != "https://gitee.com/RockChin/QChatGPT":
|
||||
reply_str += "源码: "+remote_url+"\n"
|
||||
|
||||
idx += 1
|
||||
|
||||
reply = [reply_str]
|
||||
return True, reply
|
||||
else:
|
||||
return False, []
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=PluginCommand,
|
||||
name="get",
|
||||
description="安装插件",
|
||||
usage="!plugin get <插件仓库地址>",
|
||||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class PluginGetCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import threading
|
||||
import logging
|
||||
import pkg.utils.context
|
||||
|
||||
if len(ctx.crt_params) == 0:
|
||||
reply = ["[bot]err: 请提供插件仓库地址"]
|
||||
return True, reply
|
||||
|
||||
reply = []
|
||||
def closure():
|
||||
try:
|
||||
plugin_host.install_plugin(ctx.crt_params[0])
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("插件安装成功,请发送 !reload 命令重载插件")
|
||||
except Exception as e:
|
||||
logging.error("插件安装失败:{}".format(e))
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("插件安装失败:{}".format(e))
|
||||
|
||||
threading.Thread(target=closure, args=()).start()
|
||||
reply = ["[bot]正在安装插件..."]
|
||||
return True, reply
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=PluginCommand,
|
||||
name="update",
|
||||
description="更新指定插件或全部插件",
|
||||
usage="!plugin update",
|
||||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class PluginUpdateCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import threading
|
||||
import logging
|
||||
plugin_list = plugin_host.__plugins__
|
||||
|
||||
reply = []
|
||||
|
||||
if len(ctx.crt_params) > 0:
|
||||
def closure():
|
||||
try:
|
||||
import pkg.utils.context
|
||||
|
||||
updated = []
|
||||
|
||||
if ctx.crt_params[0] == 'all':
|
||||
for key in plugin_list:
|
||||
plugin_host.update_plugin(key)
|
||||
updated.append(key)
|
||||
else:
|
||||
plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(ctx.crt_params[0])
|
||||
|
||||
if plugin_path_name is not None:
|
||||
plugin_host.update_plugin(ctx.crt_params[0])
|
||||
updated.append(ctx.crt_params[0])
|
||||
else:
|
||||
raise Exception("未找到插件: {}".format(ctx.crt_params[0]))
|
||||
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("已更新插件: {}, 请发送 !reload 重载插件".format(", ".join(updated)))
|
||||
except Exception as e:
|
||||
logging.error("插件更新失败:{}".format(e))
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("插件更新失败:{} 请使用 !plugin 命令确认插件名称或尝试手动更新插件".format(e))
|
||||
|
||||
reply = ["[bot]正在更新插件,请勿重复发起..."]
|
||||
threading.Thread(target=closure).start()
|
||||
else:
|
||||
reply = ["[bot]请指定要更新的插件, 或使用 !plugin update all 更新所有插件"]
|
||||
return True, reply
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=PluginCommand,
|
||||
name="del",
|
||||
description="删除插件",
|
||||
usage="!plugin del <插件名>",
|
||||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class PluginDelCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
plugin_list = plugin_host.__plugins__
|
||||
reply = []
|
||||
|
||||
if len(ctx.crt_params) < 1:
|
||||
reply = ["[bot]err: 未指定插件名"]
|
||||
else:
|
||||
plugin_name = ctx.crt_params[0]
|
||||
if plugin_name in plugin_list:
|
||||
unin_path = plugin_host.uninstall_plugin(plugin_name)
|
||||
reply = ["[bot]已删除插件: {} ({}), 请发送 !reload 重载插件".format(plugin_name, unin_path)]
|
||||
else:
|
||||
reply = ["[bot]err:未找到插件: {}, 请使用!plugin命令查看插件列表".format(plugin_name)]
|
||||
|
||||
return True, reply
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=PluginCommand,
|
||||
name="on",
|
||||
description="启用指定插件",
|
||||
usage="!plugin on <插件名>",
|
||||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=PluginCommand,
|
||||
name="off",
|
||||
description="禁用指定插件",
|
||||
usage="!plugin off <插件名>",
|
||||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class PluginOnOffCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.plugin.switch as plugin_switch
|
||||
|
||||
plugin_list = plugin_host.__plugins__
|
||||
reply = []
|
||||
|
||||
print(ctx.params)
|
||||
new_status = ctx.params[0] == 'on'
|
||||
|
||||
if len(ctx.crt_params) < 1:
|
||||
reply = ["[bot]err: 未指定插件名"]
|
||||
else:
|
||||
plugin_name = ctx.crt_params[0]
|
||||
if plugin_name in plugin_list:
|
||||
plugin_list[plugin_name]['enabled'] = new_status
|
||||
|
||||
for func in plugin_host.__callable_functions__:
|
||||
if func['name'].startswith(plugin_name+"-"):
|
||||
func['enabled'] = new_status
|
||||
|
||||
plugin_switch.dump_switch()
|
||||
reply = ["[bot]已{}插件: {}".format("启用" if new_status else "禁用", plugin_name)]
|
||||
else:
|
||||
reply = ["[bot]err:未找到插件: {}, 请使用!plugin命令查看插件列表".format(plugin_name)]
|
||||
|
||||
return True, reply
|
||||
|
|
@ -1,71 +0,0 @@
|
|||
from .. import aamgr
|
||||
from ....utils import context
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="default",
|
||||
description="操作情景预设",
|
||||
usage="!default\n!default set [指定情景预设为默认]",
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class DefaultCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
session_name = ctx.session_name
|
||||
params = ctx.params
|
||||
reply = []
|
||||
|
||||
config = context.get_config_manager().data
|
||||
|
||||
if len(params) == 0:
|
||||
# 输出目前所有情景预设
|
||||
import pkg.openai.dprompt as dprompt
|
||||
reply_str = "[bot]当前所有情景预设({}模式):\n\n".format(config['preset_mode'])
|
||||
|
||||
prompts = dprompt.mode_inst().list()
|
||||
|
||||
for key in prompts:
|
||||
pro = prompts[key]
|
||||
reply_str += "名称: {}".format(key)
|
||||
|
||||
for r in pro:
|
||||
reply_str += "\n - [{}]: {}".format(r['role'], r['content'])
|
||||
|
||||
reply_str += "\n\n"
|
||||
|
||||
reply_str += "\n当前默认情景预设:{}\n".format(dprompt.mode_inst().get_using_name())
|
||||
reply_str += "请使用 !default set <情景预设名称> 来设置默认情景预设"
|
||||
reply = [reply_str]
|
||||
else:
|
||||
return False, []
|
||||
|
||||
return True, reply
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=DefaultCommand,
|
||||
name="set",
|
||||
description="设置默认情景预设",
|
||||
usage="!default set <情景预设名称>",
|
||||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class DefaultSetCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
reply = []
|
||||
|
||||
if len(ctx.crt_params) == 0:
|
||||
reply = ["[bot]err: 请指定情景预设名称"]
|
||||
elif len(ctx.crt_params) > 0:
|
||||
import pkg.openai.dprompt as dprompt
|
||||
try:
|
||||
full_name = dprompt.mode_inst().set_using_name(ctx.crt_params[0])
|
||||
reply = ["[bot]已设置默认情景预设为:{}".format(full_name)]
|
||||
except Exception as e:
|
||||
reply = ["[bot]err: {}".format(e)]
|
||||
|
||||
return True, reply
|
|
@ -1,51 +0,0 @@
|
|||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="del",
|
||||
description="删除当前会话的历史记录",
|
||||
usage="!del <序号>\n!del all",
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class DelCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
session_name = ctx.session_name
|
||||
params = ctx.params
|
||||
reply = []
|
||||
if len(params) == 0:
|
||||
reply = ["[bot]参数不足, 格式: !del <序号>\n可以通过!list查看序号"]
|
||||
else:
|
||||
if params[0] == 'all':
|
||||
return False, []
|
||||
elif params[0].isdigit():
|
||||
if pkg.openai.session.get_session(session_name).delete_history(int(params[0])):
|
||||
reply = ["[bot]已删除历史会话 #{}".format(params[0])]
|
||||
else:
|
||||
reply = ["[bot]没有历史会话 #{}".format(params[0])]
|
||||
else:
|
||||
reply = ["[bot]参数错误, 格式: !del <序号>\n可以通过!list查看序号"]
|
||||
|
||||
return True, reply
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=DelCommand,
|
||||
name="all",
|
||||
description="删除当前会话的全部历史记录",
|
||||
usage="!del all",
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class DelAllCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
session_name = ctx.session_name
|
||||
reply = []
|
||||
pkg.openai.session.get_session(session_name).delete_all_history()
|
||||
reply = ["[bot]已删除所有历史会话"]
|
||||
return True, reply
|
|
@ -1,50 +0,0 @@
|
|||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="delhst",
|
||||
description="删除指定会话的所有历史记录",
|
||||
usage="!delhst <会话名称>\n!delhst all",
|
||||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class DelHistoryCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
import pkg.utils.context
|
||||
params = ctx.params
|
||||
reply = []
|
||||
if len(params) == 0:
|
||||
reply = [
|
||||
"[bot]err:请输入要删除的会话名: group_<群号> 或者 person_<QQ号>, 或使用 !delhst all 删除所有会话的历史记录"]
|
||||
else:
|
||||
if params[0] == 'all':
|
||||
return False, []
|
||||
else:
|
||||
if pkg.utils.context.get_database_manager().delete_all_history(params[0]):
|
||||
reply = ["[bot]已删除会话 {} 的所有历史记录".format(params[0])]
|
||||
else:
|
||||
reply = ["[bot]未找到会话 {} 的历史记录".format(params[0])]
|
||||
|
||||
return True, reply
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=DelHistoryCommand,
|
||||
name="all",
|
||||
description="删除所有会话的全部历史记录",
|
||||
usage="!delhst all",
|
||||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class DelAllHistoryCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.utils.context
|
||||
reply = []
|
||||
pkg.utils.context.get_database_manager().delete_all_session_history()
|
||||
reply = ["[bot]已删除所有会话的历史记录"]
|
||||
return True, reply
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
import datetime
|
||||
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="last",
|
||||
description="切换前一次对话",
|
||||
usage="!last",
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class LastCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
session_name = ctx.session_name
|
||||
|
||||
reply = []
|
||||
result = pkg.openai.session.get_session(session_name).last_session()
|
||||
if result is None:
|
||||
reply = ["[bot]没有前一次的对话"]
|
||||
else:
|
||||
datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime(
|
||||
'%Y-%m-%d %H:%M:%S')
|
||||
reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format(datetime_str)]
|
||||
|
||||
return True, reply
|
|
@ -1,65 +0,0 @@
|
|||
import datetime
|
||||
import json
|
||||
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name='list',
|
||||
description='列出当前会话的所有历史记录',
|
||||
usage='!list\n!list [页数]',
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class ListCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
session_name = ctx.session_name
|
||||
params = ctx.params
|
||||
reply = []
|
||||
|
||||
pkg.openai.session.get_session(session_name).persistence()
|
||||
page = 0
|
||||
|
||||
if len(params) > 0:
|
||||
try:
|
||||
page = int(params[0])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
results = pkg.openai.session.get_session(session_name).list_history(page=page)
|
||||
if len(results) == 0:
|
||||
reply_str = "[bot]第{}页没有历史会话".format(page)
|
||||
else:
|
||||
reply_str = "[bot]历史会话 第{}页:\n".format(page)
|
||||
current = -1
|
||||
for i in range(len(results)):
|
||||
# 时间(使用create_timestamp转换) 序号 部分内容
|
||||
datetime_obj = datetime.datetime.fromtimestamp(results[i]['create_timestamp'])
|
||||
msg = ""
|
||||
|
||||
msg = json.loads(results[i]['prompt'])
|
||||
|
||||
if len(msg) >= 2:
|
||||
reply_str += "#{} 创建:{} {}\n".format(i + page * 10,
|
||||
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
msg[0]['content'])
|
||||
else:
|
||||
reply_str += "#{} 创建:{} {}\n".format(i + page * 10,
|
||||
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"无内容")
|
||||
if results[i]['create_timestamp'] == pkg.openai.session.get_session(
|
||||
session_name).create_timestamp:
|
||||
current = i + page * 10
|
||||
|
||||
reply_str += "\n以上信息倒序排列"
|
||||
if current != -1:
|
||||
reply_str += ",当前会话是 #{}\n".format(current)
|
||||
else:
|
||||
reply_str += ",当前处于全新会话或不在此页"
|
||||
|
||||
reply = [reply_str]
|
||||
|
||||
return True, reply
|
|
@ -1,29 +0,0 @@
|
|||
import datetime
|
||||
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="next",
|
||||
description="切换后一次对话",
|
||||
usage="!next",
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class NextCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
session_name = ctx.session_name
|
||||
reply = []
|
||||
|
||||
result = pkg.openai.session.get_session(session_name).next_session()
|
||||
if result is None:
|
||||
reply = ["[bot]没有后一次的对话"]
|
||||
else:
|
||||
datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime(
|
||||
'%Y-%m-%d %H:%M:%S')
|
||||
reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format(datetime_str)]
|
||||
|
||||
return True, reply
|
|
@ -1,31 +0,0 @@
|
|||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="prompt",
|
||||
description="获取当前会话的前文",
|
||||
usage="!prompt",
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class PromptCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
session_name = ctx.session_name
|
||||
params = ctx.params
|
||||
reply = []
|
||||
|
||||
msgs = ""
|
||||
session: list = pkg.openai.session.get_session(session_name).prompt
|
||||
for msg in session:
|
||||
if len(params) != 0 and params[0] in ['-all', '-a']:
|
||||
msgs = msgs + "{}: {}\n\n".format(msg['role'], msg['content'])
|
||||
elif len(msg['content']) > 30:
|
||||
msgs = msgs + "[{}]: {}...\n\n".format(msg['role'], msg['content'][:30])
|
||||
else:
|
||||
msgs = msgs + "[{}]: {}\n\n".format(msg['role'], msg['content'])
|
||||
reply = ["[bot]当前对话所有内容:\n{}".format(msgs)]
|
||||
|
||||
return True, reply
|
|
@ -1,33 +0,0 @@
|
|||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="resend",
|
||||
description="重新获取上一次问题的回复",
|
||||
usage="!resend",
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class ResendCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
from ....openai import session as openai_session
|
||||
from ....utils import context
|
||||
from ....qqbot import message
|
||||
|
||||
session_name = ctx.session_name
|
||||
reply = []
|
||||
|
||||
session = openai_session.get_session(session_name)
|
||||
to_send = session.undo()
|
||||
|
||||
mgr = context.get_qqbot_manager()
|
||||
|
||||
config = context.get_config_manager().data
|
||||
|
||||
reply = message.process_normal_message(to_send, mgr, config,
|
||||
ctx.launcher_type, ctx.launcher_id,
|
||||
ctx.sender_id)
|
||||
|
||||
return True, reply
|
|
@ -1,35 +0,0 @@
|
|||
import tips as tips_custom
|
||||
|
||||
from .. import aamgr
|
||||
from ....openai import session
|
||||
from ....utils import context
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name='reset',
|
||||
description='重置当前会话',
|
||||
usage='!reset',
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class ResetCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
params = ctx.params
|
||||
session_name = ctx.session_name
|
||||
|
||||
reply = ""
|
||||
|
||||
if len(params) == 0:
|
||||
session.get_session(session_name).reset(explicit=True)
|
||||
reply = [tips_custom.command_reset_message]
|
||||
else:
|
||||
try:
|
||||
import pkg.openai.dprompt as dprompt
|
||||
session.get_session(session_name).reset(explicit=True, use_prompt=params[0])
|
||||
reply = [tips_custom.command_reset_name_message+"{}".format(dprompt.mode_inst().get_full_name(params[0]))]
|
||||
except Exception as e:
|
||||
reply = ["[bot]会话重置失败:{}".format(e)]
|
||||
|
||||
return True, reply
|
|
@ -1,93 +0,0 @@
|
|||
import json
|
||||
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
def config_operation(cmd, params):
|
||||
reply = []
|
||||
import pkg.utils.context
|
||||
# config = pkg.utils.context.get_config()
|
||||
cfg_mgr = pkg.utils.context.get_config_manager()
|
||||
|
||||
false = False
|
||||
true = True
|
||||
|
||||
reply_str = ""
|
||||
if len(params) == 0:
|
||||
reply = ["[bot]err:请输入!cmd cfg查看使用方法"]
|
||||
else:
|
||||
cfg_name = params[0]
|
||||
if cfg_name == 'all':
|
||||
reply_str = "[bot]所有配置项:\n\n"
|
||||
for cfg in cfg_mgr.data.keys():
|
||||
if not cfg.startswith('__') and not cfg == 'logging':
|
||||
# 根据配置项类型进行格式化,如果是字典则转换为json并格式化
|
||||
if isinstance(cfg_mgr.data[cfg], str):
|
||||
reply_str += "{}: \"{}\"\n".format(cfg, cfg_mgr.data[cfg])
|
||||
elif isinstance(cfg_mgr.data[cfg], dict):
|
||||
# 不进行unicode转义,并格式化
|
||||
reply_str += "{}: {}\n".format(cfg,
|
||||
json.dumps(cfg_mgr.data[cfg],
|
||||
ensure_ascii=False, indent=4))
|
||||
else:
|
||||
reply_str += "{}: {}\n".format(cfg, cfg_mgr.data[cfg])
|
||||
reply = [reply_str]
|
||||
else:
|
||||
cfg_entry_path = cfg_name.split('.')
|
||||
|
||||
try:
|
||||
if len(params) == 1: # 未指定配置值,返回配置项值
|
||||
cfg_entry = cfg_mgr.data[cfg_entry_path[0]]
|
||||
if len(cfg_entry_path) > 1:
|
||||
for i in range(1, len(cfg_entry_path)):
|
||||
cfg_entry = cfg_entry[cfg_entry_path[i]]
|
||||
|
||||
if isinstance(cfg_entry, str):
|
||||
reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, cfg_entry)
|
||||
elif isinstance(cfg_entry, dict):
|
||||
reply_str = "[bot]配置项{}: {}\n".format(cfg_name,
|
||||
json.dumps(cfg_entry,
|
||||
ensure_ascii=False, indent=4))
|
||||
else:
|
||||
reply_str = "[bot]配置项{}: {}\n".format(cfg_name, cfg_entry)
|
||||
reply = [reply_str]
|
||||
else:
|
||||
cfg_value = " ".join(params[1:])
|
||||
|
||||
cfg_value = eval(cfg_value)
|
||||
|
||||
cfg_entry = cfg_mgr.data[cfg_entry_path[0]]
|
||||
if len(cfg_entry_path) > 1:
|
||||
for i in range(1, len(cfg_entry_path) - 1):
|
||||
cfg_entry = cfg_entry[cfg_entry_path[i]]
|
||||
if isinstance(cfg_entry[cfg_entry_path[-1]], type(cfg_value)):
|
||||
cfg_entry[cfg_entry_path[-1]] = cfg_value
|
||||
reply = ["[bot]配置项{}修改成功".format(cfg_name)]
|
||||
else:
|
||||
reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)]
|
||||
else:
|
||||
cfg_mgr.data[cfg_entry_path[0]] = cfg_value
|
||||
reply = ["[bot]配置项{}修改成功".format(cfg_name)]
|
||||
except KeyError:
|
||||
reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
|
||||
except NameError:
|
||||
reply = ["[bot]err:值{}不合法(字符串需要使用双引号包裹)".format(cfg_value)]
|
||||
except ValueError:
|
||||
reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
|
||||
|
||||
return reply
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="cfg",
|
||||
description="配置项管理",
|
||||
usage="!cfg <配置项> [配置值]\n!cfg all",
|
||||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class CfgCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
return True, config_operation(ctx.command, ctx.params)
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="cmd",
|
||||
description="显示命令列表",
|
||||
usage="!cmd\n!cmd <命令名称>",
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class CmdCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
command_list = aamgr.__command_list__
|
||||
|
||||
reply = []
|
||||
|
||||
if len(ctx.params) == 0:
|
||||
reply_str = "[bot]当前所有命令:\n\n"
|
||||
|
||||
# 遍历顶级命令
|
||||
for key in command_list:
|
||||
command = command_list[key]
|
||||
if command['parent'] is None:
|
||||
reply_str += "!{} - {}\n".format(key, command['description'])
|
||||
|
||||
reply_str += "\n请使用 !cmd <命令名称> 来查看命令的详细信息"
|
||||
|
||||
reply = [reply_str]
|
||||
else:
|
||||
command_name = ctx.params[0]
|
||||
if command_name in command_list:
|
||||
reply = [command_list[command_name]['cls'].help()]
|
||||
else:
|
||||
reply = ["[bot]命令 {} 不存在".format(command_name)]
|
||||
|
||||
return True, reply
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="help",
|
||||
description="显示自定义的帮助信息",
|
||||
usage="!help",
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class HelpCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import tips
|
||||
reply = ["[bot] "+tips.help_message + "\n请输入 !cmd 查看命令列表"]
|
||||
|
||||
# 警告config.help_message过时
|
||||
import config
|
||||
if hasattr(config, "help_message"):
|
||||
reply[0] += "\n\n警告:config.py中的help_message已过时,不再生效,请使用tips.py中的help_message替代"
|
||||
|
||||
return True, reply
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
import threading
|
||||
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="reload",
|
||||
description="执行热重载",
|
||||
usage="!reload",
|
||||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class ReloadCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
reply = []
|
||||
|
||||
import pkg.utils.reloader
|
||||
def reload_task():
|
||||
pkg.utils.reloader.reload_all()
|
||||
|
||||
threading.Thread(target=reload_task, daemon=True).start()
|
||||
|
||||
return True, reply
|
|
@ -1,38 +0,0 @@
|
|||
import threading
|
||||
import traceback
|
||||
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="update",
|
||||
description="更新程序",
|
||||
usage="!update",
|
||||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class UpdateCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
reply = []
|
||||
import pkg.utils.updater
|
||||
import pkg.utils.reloader
|
||||
import pkg.utils.context
|
||||
|
||||
def update_task():
|
||||
try:
|
||||
if pkg.utils.updater.update_all():
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("更新完成, 请手动重启程序。")
|
||||
else:
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("无新版本")
|
||||
except Exception as e0:
|
||||
traceback.print_exc()
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("更新失败:{}".format(e0))
|
||||
return
|
||||
|
||||
threading.Thread(target=update_task, daemon=True).start()
|
||||
|
||||
reply = ["[bot]正在更新,请耐心等待,请勿重复发起更新..."]
|
||||
|
||||
return True, reply
|
|
@ -1,33 +0,0 @@
|
|||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="usage",
|
||||
description="获取使用情况",
|
||||
usage="!usage",
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class UsageCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import config
|
||||
import pkg.utils.context
|
||||
|
||||
reply = []
|
||||
|
||||
reply_str = "[bot]各api-key使用情况:\n\n"
|
||||
|
||||
api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key
|
||||
for key_name in api_keys:
|
||||
text_length = pkg.utils.context.get_openai_manager().audit_mgr \
|
||||
.get_text_length_of_key(api_keys[key_name])
|
||||
image_count = pkg.utils.context.get_openai_manager().audit_mgr \
|
||||
.get_image_count_of_key(api_keys[key_name])
|
||||
reply_str += "{}:\n - 文本长度:{}\n - 图片数量:{}\n".format(key_name, int(text_length),
|
||||
int(image_count))
|
||||
|
||||
reply = [reply_str]
|
||||
|
||||
return True, reply
|
|
@ -1,27 +0,0 @@
|
|||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="version",
|
||||
description="查看版本信息",
|
||||
usage="!version",
|
||||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class VersionCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
reply = []
|
||||
import pkg.utils.updater
|
||||
|
||||
reply_str = "[bot]当前版本:\n{}\n".format(pkg.utils.updater.get_current_version_info())
|
||||
try:
|
||||
if pkg.utils.updater.is_new_version_available():
|
||||
reply_str += "\n有新版本可用,请使用命令 !update 进行更新"
|
||||
except:
|
||||
pass
|
||||
|
||||
reply = [reply_str]
|
||||
|
||||
return True, reply
|
|
@ -1,49 +0,0 @@
|
|||
# 命令处理模块
|
||||
import logging
|
||||
|
||||
from ..qqbot.cmds import aamgr as cmdmgr
|
||||
|
||||
|
||||
def process_command(session_name: str, text_message: str, mgr, config: dict,
|
||||
launcher_type: str, launcher_id: int, sender_id: int, is_admin: bool) -> list:
|
||||
reply = []
|
||||
try:
|
||||
logging.info(
|
||||
"[{}]发起命令:{}".format(session_name, text_message[:min(20, len(text_message))] + (
|
||||
"..." if len(text_message) > 20 else "")))
|
||||
|
||||
cmd = text_message[1:].strip().split(' ')[0]
|
||||
|
||||
params = text_message[1:].strip().split(' ')[1:]
|
||||
|
||||
# 把!~开头的转换成!cfg
|
||||
if cmd.startswith('~'):
|
||||
params = [cmd[1:]] + params
|
||||
cmd = 'cfg'
|
||||
|
||||
# 包装参数
|
||||
context = cmdmgr.Context(
|
||||
command=cmd,
|
||||
crt_command=cmd,
|
||||
params=params,
|
||||
crt_params=params[:],
|
||||
session_name=session_name,
|
||||
text_message=text_message,
|
||||
launcher_type=launcher_type,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=sender_id,
|
||||
is_admin=is_admin,
|
||||
privilege=2 if is_admin else 1, # 普通用户1,管理员2
|
||||
)
|
||||
try:
|
||||
reply = cmdmgr.execute(context)
|
||||
except cmdmgr.CommandPrivilegeError as e:
|
||||
reply = ["{}".format(e)]
|
||||
|
||||
return reply
|
||||
except Exception as e:
|
||||
mgr.notify_admin("{}命令执行失败:{}".format(session_name, e))
|
||||
logging.exception(e)
|
||||
reply = ["[bot]err:{}".format(e)]
|
||||
|
||||
return reply
|
|
@ -12,10 +12,7 @@ import func_timeout
|
|||
|
||||
from ..openai import session as openai_session
|
||||
|
||||
from ..qqbot import process as processor
|
||||
from ..utils import context
|
||||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
import tips as tips_custom
|
||||
from ..qqbot import adapter as msadapter
|
||||
from .ratelim import ratelim
|
||||
|
@ -25,28 +22,20 @@ from ..core import app, entities as core_entities
|
|||
|
||||
# 控制QQ消息输入输出的类
|
||||
class QQBotManager:
|
||||
retry = 3
|
||||
|
||||
|
||||
adapter: msadapter.MessageSourceAdapter = None
|
||||
|
||||
bot_account_id: int = 0
|
||||
|
||||
ban_person = []
|
||||
ban_group = []
|
||||
|
||||
# modern
|
||||
ap: app.Application = None
|
||||
|
||||
ratelimiter: ratelim.RateLimiter = None
|
||||
|
||||
def __init__(self, first_time_init=True, ap: app.Application = None):
|
||||
config = context.get_config_manager().data
|
||||
def __init__(self, ap: app.Application = None):
|
||||
|
||||
self.ap = ap
|
||||
self.ratelimiter = ratelim.RateLimiter(ap)
|
||||
|
||||
self.timeout = config['process_message_timeout']
|
||||
self.retry = config['retry_times']
|
||||
|
||||
async def initialize(self):
|
||||
await self.ratelimiter.initialize()
|
||||
|
@ -69,10 +58,6 @@ class QQBotManager:
|
|||
from ..utils.center import apigroup
|
||||
apigroup.APIGroup._runtime_info['account_id'] = "{}".format(self.bot_account_id)
|
||||
|
||||
context.set_qqbot_manager(self)
|
||||
|
||||
# 注册诸事件
|
||||
# Caution: 注册新的事件处理器之后,请务必在unsubscribe_all中编写相应的取消订阅代码
|
||||
async def on_friend_message(event: FriendMessage):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
|
@ -144,90 +129,6 @@ class QQBotManager:
|
|||
quote_origin=True if config['quote_origin'] and check_quote else False
|
||||
)
|
||||
|
||||
async def common_process(
|
||||
self,
|
||||
launcher_type: str,
|
||||
launcher_id: int,
|
||||
text_message: str,
|
||||
message_chain: MessageChain,
|
||||
sender_id: int
|
||||
) -> mirai.MessageChain:
|
||||
"""
|
||||
私聊群聊通用消息处理方法
|
||||
"""
|
||||
# 检查bansess
|
||||
if await self.bansess_mgr.is_banned(launcher_type, launcher_id, sender_id):
|
||||
self.ap.logger.info("根据禁用列表忽略{}_{}的消息".format(launcher_type, launcher_id))
|
||||
return []
|
||||
|
||||
if mirai.Image in message_chain:
|
||||
return []
|
||||
elif sender_id == self.bot_account_id:
|
||||
return []
|
||||
else:
|
||||
# 超时则重试,重试超过次数则放弃
|
||||
failed = 0
|
||||
for i in range(self.retry):
|
||||
try:
|
||||
reply = await processor.process_message(launcher_type, launcher_id, text_message, message_chain,
|
||||
sender_id)
|
||||
return reply
|
||||
|
||||
# TODO openai 超时处理
|
||||
except func_timeout.FunctionTimedOut:
|
||||
logging.warning("{}_{}: 超时,重试中({})".format(launcher_type, launcher_id, i))
|
||||
openai_session.get_session("{}_{}".format(launcher_type, launcher_id)).release_response_lock()
|
||||
if "{}_{}".format(launcher_type, launcher_id) in processor.processing:
|
||||
processor.processing.remove("{}_{}".format(launcher_type, launcher_id))
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
if failed == self.retry:
|
||||
openai_session.get_session("{}_{}".format(launcher_type, launcher_id)).release_response_lock()
|
||||
await self.notify_admin("{} 请求超时".format("{}_{}".format(launcher_type, launcher_id)))
|
||||
reply = [tips_custom.reply_message]
|
||||
|
||||
# 私聊消息处理
|
||||
async def on_person_message(self, event: MessageEvent):
|
||||
reply = ''
|
||||
|
||||
reply = await self.common_process(
|
||||
launcher_type="person",
|
||||
launcher_id=event.sender.id,
|
||||
text_message=str(event.message_chain),
|
||||
message_chain=event.message_chain,
|
||||
sender_id=event.sender.id
|
||||
)
|
||||
|
||||
if reply:
|
||||
await self.send(event, reply, check_quote=False, check_at_sender=False)
|
||||
|
||||
# 群消息处理
|
||||
async def on_group_message(self, event: GroupMessage):
|
||||
reply = ''
|
||||
|
||||
text = str(event.message_chain).strip()
|
||||
|
||||
rule_check_res = await self.resprule_chkr.check(
|
||||
text,
|
||||
event.message_chain,
|
||||
event.group.id,
|
||||
event.sender.id
|
||||
)
|
||||
|
||||
if rule_check_res.matching:
|
||||
text = str(rule_check_res.replacement).strip()
|
||||
reply = await self.common_process(
|
||||
launcher_type="group",
|
||||
launcher_id=event.group.id,
|
||||
text_message=text,
|
||||
message_chain=rule_check_res.replacement,
|
||||
sender_id=event.sender.id
|
||||
)
|
||||
|
||||
if reply:
|
||||
await self.send(event, reply)
|
||||
|
||||
# 通知系统管理员
|
||||
async def notify_admin(self, message: str):
|
||||
await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))]))
|
||||
|
|
|
@ -1,134 +0,0 @@
|
|||
# 普通消息处理模块
|
||||
import logging
|
||||
|
||||
import openai
|
||||
|
||||
from ..utils import context
|
||||
from ..openai import session as openai_session
|
||||
|
||||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
import tips as tips_custom
|
||||
|
||||
|
||||
def handle_exception(notify_admin: str = "", set_reply: str = "") -> list:
|
||||
"""处理异常,当notify_admin不为空时,会通知管理员,返回通知用户的消息"""
|
||||
config = context.get_config_manager().data
|
||||
context.get_qqbot_manager().notify_admin(notify_admin)
|
||||
if config['hide_exce_info_to_user']:
|
||||
return [tips_custom.alter_tip_message] if tips_custom.alter_tip_message else []
|
||||
else:
|
||||
return [set_reply]
|
||||
|
||||
|
||||
def process_normal_message(text_message: str, mgr, config: dict, launcher_type: str,
|
||||
launcher_id: int, sender_id: int) -> list:
|
||||
session_name = f"{launcher_type}_{launcher_id}"
|
||||
logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + (
|
||||
"..." if len(text_message) > 20 else "")))
|
||||
|
||||
session = openai_session.get_session(session_name)
|
||||
|
||||
unexpected_exception_times = 0
|
||||
|
||||
max_unexpected_exception_times = 3
|
||||
|
||||
reply = []
|
||||
while True:
|
||||
if unexpected_exception_times >= max_unexpected_exception_times:
|
||||
reply = handle_exception(notify_admin=f"{session_name},多次尝试失败。", set_reply=f"[bot]多次尝试失败,请重试或联系管理员")
|
||||
break
|
||||
try:
|
||||
prefix = "[GPT]" if config['show_prefix'] else ""
|
||||
|
||||
text, finish_reason, funcs = session.query(text_message)
|
||||
|
||||
# 触发插件事件
|
||||
args = {
|
||||
"launcher_type": launcher_type,
|
||||
"launcher_id": launcher_id,
|
||||
"sender_id": sender_id,
|
||||
"session": session,
|
||||
"prefix": prefix,
|
||||
"response_text": text,
|
||||
"finish_reason": finish_reason,
|
||||
"funcs_called": funcs,
|
||||
}
|
||||
|
||||
event = plugin_host.emit(plugin_models.NormalMessageResponded, **args)
|
||||
|
||||
if event.get_return_value("prefix") is not None:
|
||||
prefix = event.get_return_value("prefix")
|
||||
|
||||
if event.get_return_value("reply") is not None:
|
||||
reply = event.get_return_value("reply")
|
||||
|
||||
if not event.is_prevented_default():
|
||||
reply = [prefix + text]
|
||||
|
||||
except openai.APIConnectionError as e:
|
||||
err_msg = str(e)
|
||||
if err_msg.__contains__('Error communicating with OpenAI'):
|
||||
reply = handle_exception("{}会话调用API失败:{}\n您的网络无法访问OpenAI接口或网络代理不正常".format(session_name, e),
|
||||
"[bot]err:调用API失败,请重试或联系管理员,或等待修复")
|
||||
else:
|
||||
reply = handle_exception("{}会话调用API失败:{}".format(session_name, e), "[bot]err:调用API失败,请重试或联系管理员,或等待修复")
|
||||
except openai.RateLimitError as e:
|
||||
logging.debug(type(e))
|
||||
logging.debug(e.error['message'])
|
||||
|
||||
if 'message' in e.error and e.error['message'].__contains__('You exceeded your current quota'):
|
||||
# 尝试切换api-key
|
||||
current_key_name = context.get_openai_manager().key_mgr.get_key_name(
|
||||
context.get_openai_manager().key_mgr.using_key
|
||||
)
|
||||
context.get_openai_manager().key_mgr.set_current_exceeded()
|
||||
|
||||
# 触发插件事件
|
||||
args = {
|
||||
'key_name': current_key_name,
|
||||
'usage': context.get_openai_manager().audit_mgr
|
||||
.get_usage(context.get_openai_manager().key_mgr.get_using_key_md5()),
|
||||
'exceeded_keys': context.get_openai_manager().key_mgr.exceeded,
|
||||
}
|
||||
event = plugin_host.emit(plugin_models.KeyExceeded, **args)
|
||||
|
||||
if not event.is_prevented_default():
|
||||
switched, name = context.get_openai_manager().key_mgr.auto_switch()
|
||||
|
||||
if not switched:
|
||||
reply = handle_exception(
|
||||
"api-key调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key;如果你认为这是误判,请尝试重启程序。".format(
|
||||
current_key_name), "[bot]err:API调用额度超额,请联系管理员,或等待修复")
|
||||
else:
|
||||
openai.api_key = context.get_openai_manager().key_mgr.get_using_key()
|
||||
mgr.notify_admin("api-key调用额度超限({}),接口报错,已切换到{}".format(current_key_name, name))
|
||||
reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"]
|
||||
continue
|
||||
elif 'message' in e.error and e.error['message'].__contains__('You can retry your request'):
|
||||
# 重试
|
||||
unexpected_exception_times += 1
|
||||
continue
|
||||
elif 'message' in e.error and e.error['message']\
|
||||
.__contains__('The server had an error while processing your request'):
|
||||
# 重试
|
||||
unexpected_exception_times += 1
|
||||
continue
|
||||
else:
|
||||
reply = handle_exception("{}会话调用API失败:{}".format(session_name, e),
|
||||
"[bot]err:RateLimitError,请重试或联系作者,或等待修复")
|
||||
except openai.BadRequestError as e:
|
||||
if config['auto_reset'] and "This model's maximum context length is" in str(e):
|
||||
session.reset(persist=True)
|
||||
reply = [tips_custom.session_auto_reset_message]
|
||||
else:
|
||||
reply = handle_exception("{}API调用参数错误:{}\n".format(
|
||||
session_name, e), "[bot]err:API调用参数错误,请联系管理员,或等待修复")
|
||||
except openai.APIStatusError as e:
|
||||
reply = handle_exception("{}API调用服务不可用:{}".format(session_name, e), "[bot]err:API调用服务不可用,请重试或联系管理员,或等待修复")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
reply = handle_exception("{}会话处理异常:{}".format(session_name, e), "[bot]err:{}".format(e))
|
||||
break
|
||||
|
||||
return reply
|
|
@ -1,180 +0,0 @@
|
|||
# 此模块提供了消息处理的具体逻辑的接口
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import mirai
|
||||
import logging
|
||||
|
||||
from ..qqbot import command, message
|
||||
from ..openai import session as openai_session
|
||||
from ..utils import context
|
||||
|
||||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
import tips as tips_custom
|
||||
from ..core import app
|
||||
# from .cntfilter import entities
|
||||
|
||||
processing = []
|
||||
|
||||
|
||||
def is_admin(qq: int) -> bool:
|
||||
"""兼容list和int类型的管理员判断"""
|
||||
config = context.get_config_manager().data
|
||||
if type(config['admin_qq']) == list:
|
||||
return qq in config['admin_qq']
|
||||
else:
|
||||
return qq == config['admin_qq']
|
||||
|
||||
|
||||
async def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: mirai.MessageChain,
|
||||
sender_id: int) -> list:
|
||||
global processing
|
||||
|
||||
mgr = context.get_qqbot_manager()
|
||||
|
||||
reply = []
|
||||
session_name = "{}_{}".format(launcher_type, launcher_id)
|
||||
|
||||
config = context.get_config_manager().data
|
||||
|
||||
if not config['wait_last_done'] and session_name in processing:
|
||||
return [mirai.Plain(tips_custom.message_drop_tip)]
|
||||
|
||||
# 检查是否被禁言
|
||||
if launcher_type == 'group':
|
||||
is_muted = await mgr.adapter.is_muted(launcher_id)
|
||||
if is_muted:
|
||||
logging.info("机器人被禁言,跳过消息处理(group_{})".format(launcher_id))
|
||||
return reply
|
||||
|
||||
cntfilter_res = await mgr.cntfilter_mgr.pre_process(text_message)
|
||||
if cntfilter_res.level == entities.ManagerResultLevel.INTERRUPT:
|
||||
if cntfilter_res.console_notice:
|
||||
mgr.ap.logger.info(cntfilter_res.console_notice)
|
||||
if cntfilter_res.user_notice:
|
||||
return [mirai.Plain(cntfilter_res.user_notice)]
|
||||
else:
|
||||
return []
|
||||
|
||||
openai_session.get_session(session_name).acquire_response_lock()
|
||||
|
||||
text_message = text_message.strip()
|
||||
|
||||
# 为强制消息延迟计时
|
||||
start_time = time.time()
|
||||
|
||||
# 处理消息
|
||||
try:
|
||||
|
||||
processing.append(session_name)
|
||||
try:
|
||||
msg_type = ''
|
||||
if text_message.startswith('!') or text_message.startswith("!"): # 命令
|
||||
msg_type = 'command'
|
||||
# 触发插件事件
|
||||
args = {
|
||||
'launcher_type': launcher_type,
|
||||
'launcher_id': launcher_id,
|
||||
'sender_id': sender_id,
|
||||
'command': text_message[1:].strip().split(' ')[0],
|
||||
'params': text_message[1:].strip().split(' ')[1:],
|
||||
'text_message': text_message,
|
||||
'is_admin': is_admin(sender_id),
|
||||
}
|
||||
event = plugin_host.emit(plugin_models.PersonCommandSent
|
||||
if launcher_type == 'person'
|
||||
else plugin_models.GroupCommandSent, **args)
|
||||
|
||||
if event.get_return_value("alter") is not None:
|
||||
text_message = event.get_return_value("alter")
|
||||
|
||||
# 取出插件提交的返回值赋值给reply
|
||||
if event.get_return_value("reply") is not None:
|
||||
reply = event.get_return_value("reply")
|
||||
|
||||
if not event.is_prevented_default():
|
||||
reply = command.process_command(session_name, text_message,
|
||||
mgr, config, launcher_type, launcher_id, sender_id, is_admin(sender_id))
|
||||
|
||||
else: # 消息
|
||||
msg_type = 'message'
|
||||
# 限速丢弃检查
|
||||
if not await mgr.ratelimiter.require(launcher_type, launcher_id):
|
||||
logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message))
|
||||
|
||||
return mirai.MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else []
|
||||
|
||||
before = time.time()
|
||||
# 触发插件事件
|
||||
args = {
|
||||
"launcher_type": launcher_type,
|
||||
"launcher_id": launcher_id,
|
||||
"sender_id": sender_id,
|
||||
"text_message": text_message,
|
||||
}
|
||||
event = plugin_host.emit(plugin_models.PersonNormalMessageReceived
|
||||
if launcher_type == 'person'
|
||||
else plugin_models.GroupNormalMessageReceived, **args)
|
||||
|
||||
if event.get_return_value("alter") is not None:
|
||||
text_message = event.get_return_value("alter")
|
||||
|
||||
# 取出插件提交的返回值赋值给reply
|
||||
if event.get_return_value("reply") is not None:
|
||||
reply = event.get_return_value("reply")
|
||||
|
||||
if not event.is_prevented_default():
|
||||
reply = message.process_normal_message(text_message,
|
||||
mgr, config, launcher_type, launcher_id, sender_id)
|
||||
|
||||
if reply is not None and len(reply) > 0 and (type(reply[0]) == str or type(reply[0]) == mirai.Plain):
|
||||
if type(reply[0]) == mirai.Plain:
|
||||
reply[0] = reply[0].text
|
||||
logging.info(
|
||||
"回复[{}]文字消息:{}".format(session_name,
|
||||
reply[0][:min(100, len(reply[0]))] + (
|
||||
"..." if len(reply[0]) > 100 else "")))
|
||||
if msg_type == 'message':
|
||||
cntfilter_res = await mgr.cntfilter_mgr.post_process(reply[0])
|
||||
if cntfilter_res.level == entities.ManagerResultLevel.INTERRUPT:
|
||||
if cntfilter_res.console_notice:
|
||||
mgr.ap.logger.info(cntfilter_res.console_notice)
|
||||
if cntfilter_res.user_notice:
|
||||
return [mirai.Plain(cntfilter_res.user_notice)]
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
reply = [cntfilter_res.replacement]
|
||||
|
||||
reply = await mgr.longtext_pcs.check_and_process(reply[0])
|
||||
else:
|
||||
logging.info("回复[{}]消息".format(session_name))
|
||||
|
||||
finally:
|
||||
processing.remove(session_name)
|
||||
finally:
|
||||
openai_session.get_session(session_name).release_response_lock()
|
||||
|
||||
# 检查延迟时间
|
||||
if config['force_delay_range'][1] == 0:
|
||||
delay_time = 0
|
||||
else:
|
||||
import random
|
||||
|
||||
# 从延迟范围中随机取一个值(浮点)
|
||||
rdm = random.uniform(config['force_delay_range'][0], config['force_delay_range'][1])
|
||||
|
||||
spent = time.time() - start_time
|
||||
|
||||
# 如果花费时间小于延迟时间,则延迟
|
||||
delay_time = rdm - spent if rdm - spent > 0 else 0
|
||||
|
||||
# 延迟
|
||||
if delay_time > 0:
|
||||
logging.info("[风控] 强制延迟{:.2f}秒(如需关闭,请到config.py修改force_delay_range字段)".format(delay_time))
|
||||
time.sleep(delay_time)
|
||||
|
||||
return mirai.MessageChain(reply)
|
|
@ -1 +0,0 @@
|
|||
from .threadctl import ThreadCtl
|
|
@ -1,10 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from . import threadctl
|
||||
|
||||
from ..database import manager as db_mgr
|
||||
from ..openai import manager as openai_mgr
|
||||
from ..qqbot import manager as qqbot_mgr
|
||||
from ..config import manager as config_mgr
|
||||
from ..plugin import host as plugin_host
|
||||
|
|
|
@ -1,71 +0,0 @@
|
|||
import logging
|
||||
import importlib
|
||||
import pkgutil
|
||||
import asyncio
|
||||
|
||||
from . import context
|
||||
from ..plugin import host as plugin_host
|
||||
|
||||
|
||||
def walk(module, prefix='', path_prefix=''):
|
||||
"""遍历并重载所有模块"""
|
||||
for item in pkgutil.iter_modules(module.__path__):
|
||||
if item.ispkg:
|
||||
|
||||
walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.', path_prefix + item.name + '/')
|
||||
else:
|
||||
logging.info('reload module: {}, path: {}'.format(prefix + item.name, path_prefix + item.name + '.py'))
|
||||
plugin_host.__current_module_path__ = "plugins/" + path_prefix + item.name + '.py'
|
||||
importlib.reload(__import__(module.__name__ + '.' + item.name, fromlist=['']))
|
||||
|
||||
|
||||
def reload_all(notify=True):
|
||||
# 解除bot的事件注册
|
||||
import pkg
|
||||
context.get_qqbot_manager().unsubscribe_all()
|
||||
# 执行关闭流程
|
||||
logging.info("执行程序关闭流程")
|
||||
import main
|
||||
main.stop()
|
||||
|
||||
# 删除所有已注册的命令
|
||||
import pkg.qqbot.cmds.aamgr as cmdsmgr
|
||||
cmdsmgr.__command_list__ = {}
|
||||
cmdsmgr.__tree_index__ = {}
|
||||
|
||||
# 重载所有模块
|
||||
context.context['exceeded_keys'] = context.get_openai_manager().key_mgr.exceeded
|
||||
this_context = context.context
|
||||
walk(pkg)
|
||||
importlib.reload(__import__("config-template"))
|
||||
importlib.reload(__import__('config'))
|
||||
importlib.reload(__import__('main'))
|
||||
importlib.reload(__import__('banlist'))
|
||||
importlib.reload(__import__('tips'))
|
||||
context.context = this_context
|
||||
|
||||
# 重载插件
|
||||
import plugins
|
||||
walk(plugins)
|
||||
|
||||
# 初始化相关文件
|
||||
main.check_file()
|
||||
|
||||
# 执行启动流程
|
||||
logging.info("执行程序启动流程")
|
||||
|
||||
context.get_thread_ctl().reload(
|
||||
admin_pool_num=4,
|
||||
user_pool_num=8
|
||||
)
|
||||
|
||||
def run_wrapper():
|
||||
asyncio.run(main.start_process(False))
|
||||
|
||||
context.get_thread_ctl().submit_sys_task(
|
||||
run_wrapper
|
||||
)
|
||||
|
||||
logging.info('程序启动完成')
|
||||
if notify:
|
||||
context.get_qqbot_manager().notify_admin("重载完成")
|
|
@ -1,93 +0,0 @@
|
|||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
class Pool:
|
||||
"""线程池结构"""
|
||||
pool_num:int = None
|
||||
ctl:ThreadPoolExecutor = None
|
||||
task_list:list = None
|
||||
task_list_lock:threading.Lock = None
|
||||
monitor_type = True
|
||||
|
||||
def __init__(self, pool_num):
|
||||
self.pool_num = pool_num
|
||||
self.ctl = ThreadPoolExecutor(max_workers = self.pool_num)
|
||||
self.task_list = []
|
||||
self.task_list_lock = threading.Lock()
|
||||
|
||||
def __thread_monitor__(self):
|
||||
while self.monitor_type:
|
||||
for t in self.task_list:
|
||||
if not t.done():
|
||||
continue
|
||||
try:
|
||||
self.task_list.pop(self.task_list.index(t))
|
||||
except:
|
||||
continue
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
class ThreadCtl:
|
||||
def __init__(self, sys_pool_num, admin_pool_num, user_pool_num):
|
||||
"""线程池控制类
|
||||
sys_pool_num:分配系统使用的线程池数量(>=8)
|
||||
admin_pool_num:用于处理管理员消息的线程池数量(>=1)
|
||||
user_pool_num:分配用于处理用户消息的线程池的数量(>=1)
|
||||
"""
|
||||
if sys_pool_num < 5:
|
||||
raise Exception("Too few system threads(sys_pool_num needs >= 8, but received {})".format(sys_pool_num))
|
||||
if admin_pool_num < 1:
|
||||
raise Exception("Too few admin threads(admin_pool_num needs >= 1, but received {})".format(admin_pool_num))
|
||||
if user_pool_num < 1:
|
||||
raise Exception("Too few user threads(user_pool_num needs >= 1, but received {})".format(admin_pool_num))
|
||||
self.__sys_pool__ = Pool(sys_pool_num)
|
||||
self.__admin_pool__ = Pool(admin_pool_num)
|
||||
self.__user_pool__ = Pool(user_pool_num)
|
||||
self.submit_sys_task(self.__sys_pool__.__thread_monitor__)
|
||||
self.submit_sys_task(self.__admin_pool__.__thread_monitor__)
|
||||
self.submit_sys_task(self.__user_pool__.__thread_monitor__)
|
||||
|
||||
def __submit__(self, pool: Pool, fn, /, *args, **kwargs ):
|
||||
t = pool.ctl.submit(fn, *args, **kwargs)
|
||||
pool.task_list_lock.acquire()
|
||||
pool.task_list.append(t)
|
||||
pool.task_list_lock.release()
|
||||
return t
|
||||
|
||||
def submit_sys_task(self, fn, /, *args, **kwargs):
|
||||
return self.__submit__(
|
||||
self.__sys_pool__,
|
||||
fn, *args, **kwargs
|
||||
)
|
||||
|
||||
def submit_admin_task(self, fn, /, *args, **kwargs):
|
||||
return self.__submit__(
|
||||
self.__admin_pool__,
|
||||
fn, *args, **kwargs
|
||||
)
|
||||
|
||||
def submit_user_task(self, fn, /, *args, **kwargs):
|
||||
return self.__submit__(
|
||||
self.__user_pool__,
|
||||
fn, *args, **kwargs
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
self.__user_pool__.ctl.shutdown(cancel_futures=True)
|
||||
self.__user_pool__.monitor_type = False
|
||||
self.__admin_pool__.ctl.shutdown(cancel_futures=True)
|
||||
self.__admin_pool__.monitor_type = False
|
||||
self.__sys_pool__.monitor_type = False
|
||||
self.__sys_pool__.ctl.shutdown(wait=True, cancel_futures=False)
|
||||
|
||||
def reload(self, admin_pool_num, user_pool_num):
|
||||
self.__user_pool__.ctl.shutdown(cancel_futures=True)
|
||||
self.__user_pool__.monitor_type = False
|
||||
self.__admin_pool__.ctl.shutdown(cancel_futures=True)
|
||||
self.__admin_pool__.monitor_type = False
|
||||
self.__admin_pool__ = Pool(admin_pool_num)
|
||||
self.__user_pool__ = Pool(user_pool_num)
|
||||
self.submit_sys_task(self.__admin_pool__.__thread_monitor__)
|
||||
self.submit_sys_task(self.__user_pool__.__thread_monitor__)
|
Loading…
Reference in New Issue
Block a user