mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
feat: 运行时原动态引入config的地方现在均使用初始化时导入的config对象
This commit is contained in:
parent
b318f6d4f0
commit
95ad911a6c
44
main.py
44
main.py
|
@ -1,3 +1,4 @@
|
|||
import importlib
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
|
@ -33,8 +34,10 @@ def init_db():
|
|||
|
||||
database.initialize_database()
|
||||
|
||||
|
||||
known_exception_caught = False
|
||||
|
||||
|
||||
def main(first_time_init=False):
|
||||
global known_exception_caught
|
||||
|
||||
|
@ -43,12 +46,11 @@ def main(first_time_init=False):
|
|||
# 导入config.py
|
||||
assert os.path.exists('config.py')
|
||||
|
||||
# 检查是否设置了管理员
|
||||
import config
|
||||
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
|
||||
logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
|
||||
config = importlib.import_module('config')
|
||||
|
||||
import pkg.utils.context
|
||||
pkg.utils.context.set_config(config)
|
||||
|
||||
if pkg.utils.context.context['logger_handler'] is not None:
|
||||
logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler'])
|
||||
|
||||
|
@ -69,6 +71,10 @@ def main(first_time_init=False):
|
|||
))
|
||||
logging.getLogger().addHandler(sh)
|
||||
|
||||
# 检查是否设置了管理员
|
||||
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
|
||||
logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
|
||||
|
||||
import pkg.openai.manager
|
||||
import pkg.database.manager
|
||||
import pkg.openai.session
|
||||
|
@ -98,35 +104,42 @@ def main(first_time_init=False):
|
|||
qqbot.bot.run()
|
||||
except TypeError as e:
|
||||
if str(e).__contains__("argument 'debug'"):
|
||||
logging.error("连接bot失败:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/82".format(e))
|
||||
logging.error(
|
||||
"连接bot失败:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/82".format(e))
|
||||
known_exception_caught = True
|
||||
elif str(e).__contains__("As of 3.10, the *loop*"):
|
||||
logging.error("Websockets版本过低:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/5".format(e))
|
||||
logging.error(
|
||||
"Websockets版本过低:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/5".format(e))
|
||||
known_exception_caught = True
|
||||
|
||||
except websockets.exceptions.InvalidStatus as e:
|
||||
logging.error("mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(e))
|
||||
logging.error(
|
||||
"mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(
|
||||
e))
|
||||
known_exception_caught = True
|
||||
except mirai.exceptions.NetworkError as e:
|
||||
logging.error("连接mirai-api-http失败:{}, 请检查是否已按照文档启动mirai".format(e))
|
||||
known_exception_caught = True
|
||||
except Exception as e:
|
||||
if str(e).__contains__("HTTP 404"):
|
||||
logging.error("mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(e))
|
||||
if str(e).__contains__("404"):
|
||||
logging.error(
|
||||
"mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(
|
||||
e))
|
||||
known_exception_caught = True
|
||||
else:
|
||||
logging.error("捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/issues 查找或提issue".format(e))
|
||||
logging.error(
|
||||
"捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/issues 查找或提issue".format(e))
|
||||
known_exception_caught = True
|
||||
raise e
|
||||
|
||||
qq_bot_thread = threading.Thread(target=run_bot_wrapper, args=(), daemon=True)
|
||||
qq_bot_thread.start()
|
||||
finally:
|
||||
time.sleep(10)
|
||||
time.sleep(12)
|
||||
if first_time_init:
|
||||
if not known_exception_caught:
|
||||
logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 '
|
||||
'https://github.com/RockChinQ/QChatGPT/issues/37')
|
||||
logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 '
|
||||
'https://github.com/RockChinQ/QChatGPT/issues/37')
|
||||
else:
|
||||
sys.exit(1)
|
||||
else:
|
||||
|
@ -177,10 +190,15 @@ if __name__ == '__main__':
|
|||
elif len(sys.argv) > 1 and sys.argv[1] == 'update':
|
||||
try:
|
||||
from dulwich import porcelain
|
||||
|
||||
repo = porcelain.open_repo('.')
|
||||
porcelain.pull(repo)
|
||||
except ModuleNotFoundError:
|
||||
print("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
|
||||
sys.exit(0)
|
||||
|
||||
# import pkg.utils.configmgr
|
||||
#
|
||||
# pkg.utils.configmgr.set_config_and_reload("quote_origin", False)
|
||||
|
||||
main(True)
|
||||
|
|
|
@ -6,7 +6,6 @@ from sqlite3 import Cursor
|
|||
|
||||
import sqlite3
|
||||
|
||||
import config
|
||||
import pkg.utils.context
|
||||
|
||||
|
||||
|
@ -25,7 +24,6 @@ class DatabaseManager:
|
|||
# 连接到数据库文件
|
||||
def reconnect(self):
|
||||
self.conn = sqlite3.connect('database.db', check_same_thread=False)
|
||||
# self.conn.isolation_level = None
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
def close(self):
|
||||
|
@ -127,6 +125,7 @@ class DatabaseManager:
|
|||
# 从数据库加载还没过期的session数据
|
||||
def load_valid_sessions(self) -> dict:
|
||||
# 从数据库中加载所有还没过期的session
|
||||
config = pkg.utils.context.get_config()
|
||||
self.execute("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
||||
from `sessions` where `last_interact_timestamp` > {}
|
||||
|
|
|
@ -5,7 +5,6 @@ import logging
|
|||
import pkg.database.manager
|
||||
import pkg.qqbot.manager
|
||||
import pkg.utils.context
|
||||
import config
|
||||
|
||||
|
||||
class KeysManager:
|
||||
|
@ -34,6 +33,8 @@ class KeysManager:
|
|||
def __init__(self, api_key):
|
||||
# if hasattr(config, 'api_key_usage_threshold'):
|
||||
# self.api_key_usage_threshold = config.api_key_usage_threshold
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
if hasattr(config, 'api_key_fee_threshold'):
|
||||
self.api_key_fee_threshold = config.api_key_fee_threshold
|
||||
self.load_fee()
|
||||
|
@ -108,6 +109,7 @@ class KeysManager:
|
|||
|
||||
self.fee[md5] += fee
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
if self.fee[md5] >= self.api_key_fee_threshold and \
|
||||
hasattr(config, 'auto_switch_api_key') and config.auto_switch_api_key:
|
||||
switch_result, key_name = self.auto_switch()
|
||||
|
|
|
@ -2,8 +2,6 @@ import logging
|
|||
|
||||
import openai
|
||||
|
||||
import config
|
||||
|
||||
import pkg.openai.keymgr
|
||||
import pkg.openai.pricing as pricing
|
||||
import pkg.utils.context
|
||||
|
@ -34,6 +32,7 @@ class OpenAIInteract:
|
|||
|
||||
# 请求OpenAI Completion
|
||||
def request_completion(self, prompt, stop):
|
||||
config = pkg.utils.context.get_config()
|
||||
response = openai.Completion.create(
|
||||
prompt=prompt,
|
||||
stop=stop,
|
||||
|
@ -53,6 +52,7 @@ class OpenAIInteract:
|
|||
|
||||
def request_image(self, prompt):
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params
|
||||
|
||||
response = openai.Image.create(
|
||||
|
|
|
@ -2,7 +2,6 @@ import logging
|
|||
import threading
|
||||
import time
|
||||
|
||||
import config
|
||||
import pkg.openai.manager
|
||||
import pkg.database.manager
|
||||
import pkg.utils.context
|
||||
|
@ -54,6 +53,7 @@ def dump_session(session_name: str):
|
|||
|
||||
# 从配置文件获取会话预设信息
|
||||
def get_default_prompt():
|
||||
import config
|
||||
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
|
||||
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
|
||||
return user_name + ":{}\n".format(config.default_prompt if hasattr(config, 'default_prompt') \
|
||||
|
@ -85,6 +85,8 @@ class Session:
|
|||
|
||||
prompt = get_default_prompt()
|
||||
|
||||
import config
|
||||
|
||||
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
|
||||
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
|
||||
|
||||
|
@ -130,6 +132,8 @@ class Session:
|
|||
# 不是此session已更换,退出
|
||||
if self.create_timestamp != create_timestamp or self not in sessions.values():
|
||||
return
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
if int(time.time()) - self.last_interact_timestamp > config.session_expire_time:
|
||||
logging.info('session {} 已过期'.format(self.name))
|
||||
self.reset(expired=True, schedule_new=False)
|
||||
|
@ -144,6 +148,7 @@ class Session:
|
|||
self.last_interact_timestamp = int(time.time())
|
||||
|
||||
# max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7
|
||||
config = pkg.utils.context.get_config()
|
||||
max_rounds = 1000 # 不再限制回合数
|
||||
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ import mirai.models.bus
|
|||
from mirai import At, GroupMessage, MessageEvent, Mirai, Plain, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
|
||||
FriendMessage, Image
|
||||
|
||||
import config
|
||||
import pkg.openai.session
|
||||
import pkg.openai.manager
|
||||
from func_timeout import FunctionTimedOut
|
||||
|
@ -26,6 +25,7 @@ def go(func, args=()):
|
|||
|
||||
# 检查消息是否符合泛响应匹配机制
|
||||
def check_response_rule(text: str) -> (bool, str):
|
||||
config = pkg.utils.context.get_config()
|
||||
if not hasattr(config, 'response_rules'):
|
||||
return False, ''
|
||||
|
||||
|
@ -60,6 +60,7 @@ class QQBotManager:
|
|||
self.timeout = timeout
|
||||
self.retry = retry
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
if os.path.exists("sensitive.json") \
|
||||
and config.sensitive_word_filter is not None \
|
||||
and config.sensitive_word_filter:
|
||||
|
@ -134,6 +135,7 @@ class QQBotManager:
|
|||
self.bot = bot
|
||||
|
||||
def send(self, event, msg, check_quote=True):
|
||||
config = pkg.utils.context.get_config()
|
||||
asyncio.run(
|
||||
self.bot.send(event, msg, quote=True if hasattr(config,
|
||||
"quote_origin") and config.quote_origin and check_quote else False))
|
||||
|
@ -216,6 +218,7 @@ class QQBotManager:
|
|||
|
||||
# 通知系统管理员
|
||||
def notify_admin(self, message: str):
|
||||
config = pkg.utils.context.get_config()
|
||||
if hasattr(config, "admin_qq") and config.admin_qq != 0:
|
||||
logging.info("通知管理员:{}".format(message))
|
||||
send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message))
|
||||
|
|
|
@ -9,6 +9,9 @@ import openai
|
|||
|
||||
from mirai import Image, MessageChain
|
||||
|
||||
# 这里不使用动态引入config
|
||||
# 因为在这里动态引入会卡死程序
|
||||
# 而此模块静态引用config与动态引入的表现一致
|
||||
import config
|
||||
|
||||
import pkg.openai.session
|
||||
|
|
|
@ -9,9 +9,18 @@ context = {
|
|||
'qqbot.manager.QQBotManager': None,
|
||||
},
|
||||
'logger_handler': None,
|
||||
'config': None,
|
||||
}
|
||||
|
||||
|
||||
def set_config(inst):
|
||||
context['config'] = inst
|
||||
|
||||
|
||||
def get_config():
|
||||
return context['config']
|
||||
|
||||
|
||||
def set_database_manager(inst):
|
||||
context['inst']['database.manager.DatabaseManager'] = inst
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user