feat: 支持聊天窗管理员修改运行时配置项

This commit is contained in:
Rock Chin 2023-01-04 21:46:01 +08:00
parent 2f42bbc6c8
commit 07923f71bd
2 changed files with 81 additions and 17 deletions

View File

@ -51,16 +51,6 @@ def dump_session(session_name: str):
del sessions[session_name]
# 从配置文件获取会话预设信息
def get_default_prompt():
import config
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
return user_name + ":{}\n".format(config.default_prompt if hasattr(config, 'default_prompt') \
and config.default_prompt != "" else '') + \
bot_name + ":好的\n"
# def blocked_func(lock: threading.Lock):
#
# def decorator(func):
@ -83,7 +73,7 @@ def get_default_prompt():
class Session:
name = ''
prompt = get_default_prompt()
prompt = ""
import config
@ -111,6 +101,15 @@ class Session:
self.response_lock.release()
logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock))
# 从配置文件获取会话预设信息
def get_default_prompt(self):
config = pkg.utils.context.get_config()
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
return user_name + ":{}\n".format(config.default_prompt if hasattr(config, 'default_prompt') \
and config.default_prompt != "" else '') + \
bot_name + ":好的\n"
def __init__(self, name: str):
self.name = name
self.create_timestamp = int(time.time())
@ -118,6 +117,7 @@ class Session:
self.schedule()
self.response_lock = threading.Lock()
self.prompt = self.get_default_prompt()
# 设定检查session最后一次对话是否超过过期时间的计时器
def schedule(self):
@ -206,7 +206,7 @@ class Session:
# 持久化session
def persistence(self):
if self.prompt == get_default_prompt():
if self.prompt == self.get_default_prompt():
return
db_inst = pkg.utils.context.get_database_manager()
@ -221,14 +221,14 @@ class Session:
# 重置session
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True):
if self.prompt != get_default_prompt():
if self.prompt != self.get_default_prompt():
self.persistence()
if explicit:
pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
if expired:
pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
self.prompt = get_default_prompt()
self.prompt = self.get_default_prompt()
self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time())
self.just_switched_to_exist_session = False
@ -274,7 +274,7 @@ class Session:
def list_history(self, capacity: int = 10, page: int = 0):
return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page,
get_default_prompt())
self.get_default_prompt())
def draw_image(self, prompt: str):
return pkg.utils.context.get_openai_manager().request_image(prompt)

View File

@ -1,6 +1,7 @@
# 此模块提供了消息处理的具体逻辑的接口
import asyncio
import datetime
import json
import threading
from func_timeout import func_set_timeout
@ -12,7 +13,7 @@ from mirai import Image, MessageChain
# 这里不使用动态引入config
# 因为在这里动态引入会卡死程序
# 而此模块静态引用config与动态引入的表现一致
import config
import config as config_init_import
import pkg.openai.session
import pkg.openai.manager
@ -23,7 +24,7 @@ import pkg.utils.context
processing = []
@func_set_timeout(config.process_message_timeout)
@func_set_timeout(config_init_import.process_message_timeout)
def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: MessageChain,
sender_id: int) -> MessageChain:
global processing
@ -51,6 +52,8 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
processing.append(session_name)
config = pkg.utils.context.get_config()
try:
if text_message.startswith('!') or text_message.startswith(""): # 指令
@ -189,6 +192,67 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
pkg.utils.context.get_qqbot_manager().notify_admin("更新完成")
threading.Thread(target=update_task, daemon=True).start()
elif cmd == 'cfg' and launcher_type == 'person' and launcher_id == config.admin_qq:
reply_str = ""
if len(params) == 0:
reply = ["[bot]err:请输入配置项"]
else:
cfg_name = params[0]
if cfg_name == 'all':
reply_str = "[bot]所有配置项:\n\n"
for cfg in dir(config):
print(cfg)
if not cfg.startswith('__') and not cfg == 'logging':
# 根据配置项类型进行格式化如果是字典则转换为json并格式化
if isinstance(getattr(config, cfg), str):
reply_str += "{}: \"{}\"\n".format(cfg, getattr(config, cfg))
elif isinstance(getattr(config, cfg), dict):
# 不进行unicode转义并格式化
reply_str += "{}: {}\n".format(cfg,
json.dumps(getattr(config, cfg),
ensure_ascii=False, indent=4))
else:
reply_str += "{}: {}\n".format(cfg, getattr(config, cfg))
reply = [reply_str]
elif cfg_name in dir(config):
if len(params) == 1:
# 按照配置项类型进行格式化
if isinstance(getattr(config, cfg_name), str):
reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, getattr(config, cfg_name))
elif isinstance(getattr(config, cfg_name), dict):
reply_str = "[bot]配置项{}: {}\n".format(cfg_name,
json.dumps(getattr(config, cfg_name),
ensure_ascii=False, indent=4))
else:
reply_str = "[bot]配置项{}: {}\n".format(cfg_name, getattr(config, cfg_name))
reply = [reply_str]
else:
cfg_value = " ".join(params[1:])
# 类型转换如果是json则转换为字典
if cfg_value == 'true':
cfg_value = True
elif cfg_value == 'false':
cfg_value = False
elif cfg_value.isdigit():
cfg_value = int(cfg_value)
elif cfg_value.startswith('{') and cfg_value.endswith('}'):
cfg_value = json.loads(cfg_value)
else:
try:
cfg_value = float(cfg_value)
except ValueError:
pass
# 检查类型是否匹配
if isinstance(getattr(config, cfg_name), type(cfg_value)):
setattr(config, cfg_name, cfg_value)
pkg.utils.context.set_config(config)
reply = ["[bot]配置项{}修改成功".format(cfg_name)]
else:
reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)]
else:
reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
else:
reply = ["[bot]err:未知的指令或权限不足: "+cmd]
except Exception as e: