mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
feat: 支持聊天窗管理员修改运行时配置项
This commit is contained in:
parent
2f42bbc6c8
commit
07923f71bd
|
@ -51,16 +51,6 @@ def dump_session(session_name: str):
|
|||
del sessions[session_name]
|
||||
|
||||
|
||||
# 从配置文件获取会话预设信息
|
||||
def get_default_prompt():
|
||||
import config
|
||||
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
|
||||
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
|
||||
return user_name + ":{}\n".format(config.default_prompt if hasattr(config, 'default_prompt') \
|
||||
and config.default_prompt != "" else '') + \
|
||||
bot_name + ":好的\n"
|
||||
|
||||
|
||||
# def blocked_func(lock: threading.Lock):
|
||||
#
|
||||
# def decorator(func):
|
||||
|
@ -83,7 +73,7 @@ def get_default_prompt():
|
|||
class Session:
|
||||
name = ''
|
||||
|
||||
prompt = get_default_prompt()
|
||||
prompt = ""
|
||||
|
||||
import config
|
||||
|
||||
|
@ -111,6 +101,15 @@ class Session:
|
|||
self.response_lock.release()
|
||||
logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock))
|
||||
|
||||
# 从配置文件获取会话预设信息
|
||||
def get_default_prompt(self):
|
||||
config = pkg.utils.context.get_config()
|
||||
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
|
||||
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
|
||||
return user_name + ":{}\n".format(config.default_prompt if hasattr(config, 'default_prompt') \
|
||||
and config.default_prompt != "" else '') + \
|
||||
bot_name + ":好的\n"
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.create_timestamp = int(time.time())
|
||||
|
@ -118,6 +117,7 @@ class Session:
|
|||
self.schedule()
|
||||
|
||||
self.response_lock = threading.Lock()
|
||||
self.prompt = self.get_default_prompt()
|
||||
|
||||
# 设定检查session最后一次对话是否超过过期时间的计时器
|
||||
def schedule(self):
|
||||
|
@ -206,7 +206,7 @@ class Session:
|
|||
|
||||
# 持久化session
|
||||
def persistence(self):
|
||||
if self.prompt == get_default_prompt():
|
||||
if self.prompt == self.get_default_prompt():
|
||||
return
|
||||
|
||||
db_inst = pkg.utils.context.get_database_manager()
|
||||
|
@ -221,14 +221,14 @@ class Session:
|
|||
|
||||
# 重置session
|
||||
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True):
|
||||
if self.prompt != get_default_prompt():
|
||||
if self.prompt != self.get_default_prompt():
|
||||
self.persistence()
|
||||
if explicit:
|
||||
pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
|
||||
|
||||
if expired:
|
||||
pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
|
||||
self.prompt = get_default_prompt()
|
||||
self.prompt = self.get_default_prompt()
|
||||
self.create_timestamp = int(time.time())
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
self.just_switched_to_exist_session = False
|
||||
|
@ -274,7 +274,7 @@ class Session:
|
|||
|
||||
def list_history(self, capacity: int = 10, page: int = 0):
|
||||
return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page,
|
||||
get_default_prompt())
|
||||
self.get_default_prompt())
|
||||
|
||||
def draw_image(self, prompt: str):
|
||||
return pkg.utils.context.get_openai_manager().request_image(prompt)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# 此模块提供了消息处理的具体逻辑的接口
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import threading
|
||||
|
||||
from func_timeout import func_set_timeout
|
||||
|
@ -12,7 +13,7 @@ from mirai import Image, MessageChain
|
|||
# 这里不使用动态引入config
|
||||
# 因为在这里动态引入会卡死程序
|
||||
# 而此模块静态引用config与动态引入的表现一致
|
||||
import config
|
||||
import config as config_init_import
|
||||
|
||||
import pkg.openai.session
|
||||
import pkg.openai.manager
|
||||
|
@ -23,7 +24,7 @@ import pkg.utils.context
|
|||
processing = []
|
||||
|
||||
|
||||
@func_set_timeout(config.process_message_timeout)
|
||||
@func_set_timeout(config_init_import.process_message_timeout)
|
||||
def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: MessageChain,
|
||||
sender_id: int) -> MessageChain:
|
||||
global processing
|
||||
|
@ -51,6 +52,8 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
|||
|
||||
processing.append(session_name)
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
|
||||
try:
|
||||
|
||||
if text_message.startswith('!') or text_message.startswith("!"): # 指令
|
||||
|
@ -189,6 +192,67 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
|||
pkg.utils.context.get_qqbot_manager().notify_admin("更新完成")
|
||||
|
||||
threading.Thread(target=update_task, daemon=True).start()
|
||||
elif cmd == 'cfg' and launcher_type == 'person' and launcher_id == config.admin_qq:
|
||||
reply_str = ""
|
||||
if len(params) == 0:
|
||||
reply = ["[bot]err:请输入配置项"]
|
||||
else:
|
||||
cfg_name = params[0]
|
||||
if cfg_name == 'all':
|
||||
reply_str = "[bot]所有配置项:\n\n"
|
||||
for cfg in dir(config):
|
||||
print(cfg)
|
||||
if not cfg.startswith('__') and not cfg == 'logging':
|
||||
# 根据配置项类型进行格式化,如果是字典则转换为json并格式化
|
||||
if isinstance(getattr(config, cfg), str):
|
||||
reply_str += "{}: \"{}\"\n".format(cfg, getattr(config, cfg))
|
||||
elif isinstance(getattr(config, cfg), dict):
|
||||
# 不进行unicode转义,并格式化
|
||||
reply_str += "{}: {}\n".format(cfg,
|
||||
json.dumps(getattr(config, cfg),
|
||||
ensure_ascii=False, indent=4))
|
||||
else:
|
||||
reply_str += "{}: {}\n".format(cfg, getattr(config, cfg))
|
||||
reply = [reply_str]
|
||||
elif cfg_name in dir(config):
|
||||
if len(params) == 1:
|
||||
# 按照配置项类型进行格式化
|
||||
if isinstance(getattr(config, cfg_name), str):
|
||||
reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, getattr(config, cfg_name))
|
||||
elif isinstance(getattr(config, cfg_name), dict):
|
||||
reply_str = "[bot]配置项{}: {}\n".format(cfg_name,
|
||||
json.dumps(getattr(config, cfg_name),
|
||||
ensure_ascii=False, indent=4))
|
||||
else:
|
||||
reply_str = "[bot]配置项{}: {}\n".format(cfg_name, getattr(config, cfg_name))
|
||||
reply = [reply_str]
|
||||
else:
|
||||
cfg_value = " ".join(params[1:])
|
||||
# 类型转换,如果是json则转换为字典
|
||||
if cfg_value == 'true':
|
||||
cfg_value = True
|
||||
elif cfg_value == 'false':
|
||||
cfg_value = False
|
||||
elif cfg_value.isdigit():
|
||||
cfg_value = int(cfg_value)
|
||||
elif cfg_value.startswith('{') and cfg_value.endswith('}'):
|
||||
cfg_value = json.loads(cfg_value)
|
||||
else:
|
||||
try:
|
||||
cfg_value = float(cfg_value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# 检查类型是否匹配
|
||||
if isinstance(getattr(config, cfg_name), type(cfg_value)):
|
||||
setattr(config, cfg_name, cfg_value)
|
||||
pkg.utils.context.set_config(config)
|
||||
reply = ["[bot]配置项{}修改成功".format(cfg_name)]
|
||||
else:
|
||||
reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)]
|
||||
|
||||
else:
|
||||
reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
|
||||
else:
|
||||
reply = ["[bot]err:未知的指令或权限不足: "+cmd]
|
||||
except Exception as e:
|
||||
|
|
Loading…
Reference in New Issue
Block a user