feat: 运行时原动态引入config的地方现在均使用初始化时导入的config对象

This commit is contained in:
Rock Chin 2023-01-04 17:09:57 +08:00
parent b318f6d4f0
commit 95ad911a6c
8 changed files with 59 additions and 20 deletions

44
main.py
View File

@ -1,3 +1,4 @@
import importlib
import os import os
import shutil import shutil
import threading import threading
@ -33,8 +34,10 @@ def init_db():
database.initialize_database() database.initialize_database()
known_exception_caught = False known_exception_caught = False
def main(first_time_init=False): def main(first_time_init=False):
global known_exception_caught global known_exception_caught
@ -43,12 +46,11 @@ def main(first_time_init=False):
# 导入config.py # 导入config.py
assert os.path.exists('config.py') assert os.path.exists('config.py')
# 检查是否设置了管理员 config = importlib.import_module('config')
import config
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
import pkg.utils.context import pkg.utils.context
pkg.utils.context.set_config(config)
if pkg.utils.context.context['logger_handler'] is not None: if pkg.utils.context.context['logger_handler'] is not None:
logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler']) logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler'])
@ -69,6 +71,10 @@ def main(first_time_init=False):
)) ))
logging.getLogger().addHandler(sh) logging.getLogger().addHandler(sh)
# 检查是否设置了管理员
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
import pkg.openai.manager import pkg.openai.manager
import pkg.database.manager import pkg.database.manager
import pkg.openai.session import pkg.openai.session
@ -98,35 +104,42 @@ def main(first_time_init=False):
qqbot.bot.run() qqbot.bot.run()
except TypeError as e: except TypeError as e:
if str(e).__contains__("argument 'debug'"): if str(e).__contains__("argument 'debug'"):
logging.error("连接bot失败:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/82".format(e)) logging.error(
"连接bot失败:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/82".format(e))
known_exception_caught = True known_exception_caught = True
elif str(e).__contains__("As of 3.10, the *loop*"): elif str(e).__contains__("As of 3.10, the *loop*"):
logging.error("Websockets版本过低:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/5".format(e)) logging.error(
"Websockets版本过低:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/5".format(e))
known_exception_caught = True known_exception_caught = True
except websockets.exceptions.InvalidStatus as e: except websockets.exceptions.InvalidStatus as e:
logging.error("mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(e)) logging.error(
"mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(
e))
known_exception_caught = True known_exception_caught = True
except mirai.exceptions.NetworkError as e: except mirai.exceptions.NetworkError as e:
logging.error("连接mirai-api-http失败:{}, 请检查是否已按照文档启动mirai".format(e)) logging.error("连接mirai-api-http失败:{}, 请检查是否已按照文档启动mirai".format(e))
known_exception_caught = True known_exception_caught = True
except Exception as e: except Exception as e:
if str(e).__contains__("HTTP 404"): if str(e).__contains__("404"):
logging.error("mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(e)) logging.error(
"mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(
e))
known_exception_caught = True known_exception_caught = True
else: else:
logging.error("捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/issues 查找或提issue".format(e)) logging.error(
"捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/issues 查找或提issue".format(e))
known_exception_caught = True known_exception_caught = True
raise e raise e
qq_bot_thread = threading.Thread(target=run_bot_wrapper, args=(), daemon=True) qq_bot_thread = threading.Thread(target=run_bot_wrapper, args=(), daemon=True)
qq_bot_thread.start() qq_bot_thread.start()
finally: finally:
time.sleep(10) time.sleep(12)
if first_time_init: if first_time_init:
if not known_exception_caught: if not known_exception_caught:
logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 ' logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 '
'https://github.com/RockChinQ/QChatGPT/issues/37') 'https://github.com/RockChinQ/QChatGPT/issues/37')
else: else:
sys.exit(1) sys.exit(1)
else: else:
@ -177,10 +190,15 @@ if __name__ == '__main__':
elif len(sys.argv) > 1 and sys.argv[1] == 'update': elif len(sys.argv) > 1 and sys.argv[1] == 'update':
try: try:
from dulwich import porcelain from dulwich import porcelain
repo = porcelain.open_repo('.') repo = porcelain.open_repo('.')
porcelain.pull(repo) porcelain.pull(repo)
except ModuleNotFoundError: except ModuleNotFoundError:
print("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77") print("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
sys.exit(0) sys.exit(0)
# import pkg.utils.configmgr
#
# pkg.utils.configmgr.set_config_and_reload("quote_origin", False)
main(True) main(True)

View File

@ -6,7 +6,6 @@ from sqlite3 import Cursor
import sqlite3 import sqlite3
import config
import pkg.utils.context import pkg.utils.context
@ -25,7 +24,6 @@ class DatabaseManager:
# 连接到数据库文件 # 连接到数据库文件
def reconnect(self): def reconnect(self):
self.conn = sqlite3.connect('database.db', check_same_thread=False) self.conn = sqlite3.connect('database.db', check_same_thread=False)
# self.conn.isolation_level = None
self.cursor = self.conn.cursor() self.cursor = self.conn.cursor()
def close(self): def close(self):
@ -127,6 +125,7 @@ class DatabaseManager:
# 从数据库加载还没过期的session数据 # 从数据库加载还没过期的session数据
def load_valid_sessions(self) -> dict: def load_valid_sessions(self) -> dict:
# 从数据库中加载所有还没过期的session # 从数据库中加载所有还没过期的session
config = pkg.utils.context.get_config()
self.execute(""" self.execute("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
from `sessions` where `last_interact_timestamp` > {} from `sessions` where `last_interact_timestamp` > {}

View File

@ -5,7 +5,6 @@ import logging
import pkg.database.manager import pkg.database.manager
import pkg.qqbot.manager import pkg.qqbot.manager
import pkg.utils.context import pkg.utils.context
import config
class KeysManager: class KeysManager:
@ -34,6 +33,8 @@ class KeysManager:
def __init__(self, api_key): def __init__(self, api_key):
# if hasattr(config, 'api_key_usage_threshold'): # if hasattr(config, 'api_key_usage_threshold'):
# self.api_key_usage_threshold = config.api_key_usage_threshold # self.api_key_usage_threshold = config.api_key_usage_threshold
config = pkg.utils.context.get_config()
if hasattr(config, 'api_key_fee_threshold'): if hasattr(config, 'api_key_fee_threshold'):
self.api_key_fee_threshold = config.api_key_fee_threshold self.api_key_fee_threshold = config.api_key_fee_threshold
self.load_fee() self.load_fee()
@ -108,6 +109,7 @@ class KeysManager:
self.fee[md5] += fee self.fee[md5] += fee
config = pkg.utils.context.get_config()
if self.fee[md5] >= self.api_key_fee_threshold and \ if self.fee[md5] >= self.api_key_fee_threshold and \
hasattr(config, 'auto_switch_api_key') and config.auto_switch_api_key: hasattr(config, 'auto_switch_api_key') and config.auto_switch_api_key:
switch_result, key_name = self.auto_switch() switch_result, key_name = self.auto_switch()

View File

@ -2,8 +2,6 @@ import logging
import openai import openai
import config
import pkg.openai.keymgr import pkg.openai.keymgr
import pkg.openai.pricing as pricing import pkg.openai.pricing as pricing
import pkg.utils.context import pkg.utils.context
@ -34,6 +32,7 @@ class OpenAIInteract:
# 请求OpenAI Completion # 请求OpenAI Completion
def request_completion(self, prompt, stop): def request_completion(self, prompt, stop):
config = pkg.utils.context.get_config()
response = openai.Completion.create( response = openai.Completion.create(
prompt=prompt, prompt=prompt,
stop=stop, stop=stop,
@ -53,6 +52,7 @@ class OpenAIInteract:
def request_image(self, prompt): def request_image(self, prompt):
config = pkg.utils.context.get_config()
params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params
response = openai.Image.create( response = openai.Image.create(

View File

@ -2,7 +2,6 @@ import logging
import threading import threading
import time import time
import config
import pkg.openai.manager import pkg.openai.manager
import pkg.database.manager import pkg.database.manager
import pkg.utils.context import pkg.utils.context
@ -54,6 +53,7 @@ def dump_session(session_name: str):
# 从配置文件获取会话预设信息 # 从配置文件获取会话预设信息
def get_default_prompt(): def get_default_prompt():
import config
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You' user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot' bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
return user_name + ":{}\n".format(config.default_prompt if hasattr(config, 'default_prompt') \ return user_name + ":{}\n".format(config.default_prompt if hasattr(config, 'default_prompt') \
@ -85,6 +85,8 @@ class Session:
prompt = get_default_prompt() prompt = get_default_prompt()
import config
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You' user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot' bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
@ -130,6 +132,8 @@ class Session:
# 不是此session已更换退出 # 不是此session已更换退出
if self.create_timestamp != create_timestamp or self not in sessions.values(): if self.create_timestamp != create_timestamp or self not in sessions.values():
return return
config = pkg.utils.context.get_config()
if int(time.time()) - self.last_interact_timestamp > config.session_expire_time: if int(time.time()) - self.last_interact_timestamp > config.session_expire_time:
logging.info('session {} 已过期'.format(self.name)) logging.info('session {} 已过期'.format(self.name))
self.reset(expired=True, schedule_new=False) self.reset(expired=True, schedule_new=False)
@ -144,6 +148,7 @@ class Session:
self.last_interact_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time())
# max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7 # max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7
config = pkg.utils.context.get_config()
max_rounds = 1000 # 不再限制回合数 max_rounds = 1000 # 不再限制回合数
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024 max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024

View File

@ -7,7 +7,6 @@ import mirai.models.bus
from mirai import At, GroupMessage, MessageEvent, Mirai, Plain, StrangerMessage, WebSocketAdapter, HTTPAdapter, \ from mirai import At, GroupMessage, MessageEvent, Mirai, Plain, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
FriendMessage, Image FriendMessage, Image
import config
import pkg.openai.session import pkg.openai.session
import pkg.openai.manager import pkg.openai.manager
from func_timeout import FunctionTimedOut from func_timeout import FunctionTimedOut
@ -26,6 +25,7 @@ def go(func, args=()):
# 检查消息是否符合泛响应匹配机制 # 检查消息是否符合泛响应匹配机制
def check_response_rule(text: str) -> (bool, str): def check_response_rule(text: str) -> (bool, str):
config = pkg.utils.context.get_config()
if not hasattr(config, 'response_rules'): if not hasattr(config, 'response_rules'):
return False, '' return False, ''
@ -60,6 +60,7 @@ class QQBotManager:
self.timeout = timeout self.timeout = timeout
self.retry = retry self.retry = retry
config = pkg.utils.context.get_config()
if os.path.exists("sensitive.json") \ if os.path.exists("sensitive.json") \
and config.sensitive_word_filter is not None \ and config.sensitive_word_filter is not None \
and config.sensitive_word_filter: and config.sensitive_word_filter:
@ -134,6 +135,7 @@ class QQBotManager:
self.bot = bot self.bot = bot
def send(self, event, msg, check_quote=True): def send(self, event, msg, check_quote=True):
config = pkg.utils.context.get_config()
asyncio.run( asyncio.run(
self.bot.send(event, msg, quote=True if hasattr(config, self.bot.send(event, msg, quote=True if hasattr(config,
"quote_origin") and config.quote_origin and check_quote else False)) "quote_origin") and config.quote_origin and check_quote else False))
@ -216,6 +218,7 @@ class QQBotManager:
# 通知系统管理员 # 通知系统管理员
def notify_admin(self, message: str): def notify_admin(self, message: str):
config = pkg.utils.context.get_config()
if hasattr(config, "admin_qq") and config.admin_qq != 0: if hasattr(config, "admin_qq") and config.admin_qq != 0:
logging.info("通知管理员:{}".format(message)) logging.info("通知管理员:{}".format(message))
send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message)) send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message))

View File

@ -9,6 +9,9 @@ import openai
from mirai import Image, MessageChain from mirai import Image, MessageChain
# 这里不使用动态引入config
# 因为在这里动态引入会卡死程序
# 而此模块静态引用config与动态引入的表现一致
import config import config
import pkg.openai.session import pkg.openai.session

View File

@ -9,9 +9,18 @@ context = {
'qqbot.manager.QQBotManager': None, 'qqbot.manager.QQBotManager': None,
}, },
'logger_handler': None, 'logger_handler': None,
'config': None,
} }
def set_config(inst):
context['config'] = inst
def get_config():
return context['config']
def set_database_manager(inst): def set_database_manager(inst):
context['inst']['database.manager.DatabaseManager'] = inst context['inst']['database.manager.DatabaseManager'] = inst