mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
feat: 运行时原动态引入config的地方现在均使用初始化时导入的config对象
This commit is contained in:
parent
b318f6d4f0
commit
95ad911a6c
44
main.py
44
main.py
|
@ -1,3 +1,4 @@
|
||||||
|
import importlib
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import threading
|
import threading
|
||||||
|
@ -33,8 +34,10 @@ def init_db():
|
||||||
|
|
||||||
database.initialize_database()
|
database.initialize_database()
|
||||||
|
|
||||||
|
|
||||||
known_exception_caught = False
|
known_exception_caught = False
|
||||||
|
|
||||||
|
|
||||||
def main(first_time_init=False):
|
def main(first_time_init=False):
|
||||||
global known_exception_caught
|
global known_exception_caught
|
||||||
|
|
||||||
|
@ -43,12 +46,11 @@ def main(first_time_init=False):
|
||||||
# 导入config.py
|
# 导入config.py
|
||||||
assert os.path.exists('config.py')
|
assert os.path.exists('config.py')
|
||||||
|
|
||||||
# 检查是否设置了管理员
|
config = importlib.import_module('config')
|
||||||
import config
|
|
||||||
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
|
|
||||||
logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
|
|
||||||
|
|
||||||
import pkg.utils.context
|
import pkg.utils.context
|
||||||
|
pkg.utils.context.set_config(config)
|
||||||
|
|
||||||
if pkg.utils.context.context['logger_handler'] is not None:
|
if pkg.utils.context.context['logger_handler'] is not None:
|
||||||
logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler'])
|
logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler'])
|
||||||
|
|
||||||
|
@ -69,6 +71,10 @@ def main(first_time_init=False):
|
||||||
))
|
))
|
||||||
logging.getLogger().addHandler(sh)
|
logging.getLogger().addHandler(sh)
|
||||||
|
|
||||||
|
# 检查是否设置了管理员
|
||||||
|
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
|
||||||
|
logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
|
||||||
|
|
||||||
import pkg.openai.manager
|
import pkg.openai.manager
|
||||||
import pkg.database.manager
|
import pkg.database.manager
|
||||||
import pkg.openai.session
|
import pkg.openai.session
|
||||||
|
@ -98,35 +104,42 @@ def main(first_time_init=False):
|
||||||
qqbot.bot.run()
|
qqbot.bot.run()
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
if str(e).__contains__("argument 'debug'"):
|
if str(e).__contains__("argument 'debug'"):
|
||||||
logging.error("连接bot失败:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/82".format(e))
|
logging.error(
|
||||||
|
"连接bot失败:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/82".format(e))
|
||||||
known_exception_caught = True
|
known_exception_caught = True
|
||||||
elif str(e).__contains__("As of 3.10, the *loop*"):
|
elif str(e).__contains__("As of 3.10, the *loop*"):
|
||||||
logging.error("Websockets版本过低:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/5".format(e))
|
logging.error(
|
||||||
|
"Websockets版本过低:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/5".format(e))
|
||||||
known_exception_caught = True
|
known_exception_caught = True
|
||||||
|
|
||||||
except websockets.exceptions.InvalidStatus as e:
|
except websockets.exceptions.InvalidStatus as e:
|
||||||
logging.error("mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(e))
|
logging.error(
|
||||||
|
"mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(
|
||||||
|
e))
|
||||||
known_exception_caught = True
|
known_exception_caught = True
|
||||||
except mirai.exceptions.NetworkError as e:
|
except mirai.exceptions.NetworkError as e:
|
||||||
logging.error("连接mirai-api-http失败:{}, 请检查是否已按照文档启动mirai".format(e))
|
logging.error("连接mirai-api-http失败:{}, 请检查是否已按照文档启动mirai".format(e))
|
||||||
known_exception_caught = True
|
known_exception_caught = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).__contains__("HTTP 404"):
|
if str(e).__contains__("404"):
|
||||||
logging.error("mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(e))
|
logging.error(
|
||||||
|
"mirai-api-http端口无法使用:{}, 请查看 https://github.com/RockChinQ/QChatGPT/issues/22".format(
|
||||||
|
e))
|
||||||
known_exception_caught = True
|
known_exception_caught = True
|
||||||
else:
|
else:
|
||||||
logging.error("捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/issues 查找或提issue".format(e))
|
logging.error(
|
||||||
|
"捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/issues 查找或提issue".format(e))
|
||||||
known_exception_caught = True
|
known_exception_caught = True
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
qq_bot_thread = threading.Thread(target=run_bot_wrapper, args=(), daemon=True)
|
qq_bot_thread = threading.Thread(target=run_bot_wrapper, args=(), daemon=True)
|
||||||
qq_bot_thread.start()
|
qq_bot_thread.start()
|
||||||
finally:
|
finally:
|
||||||
time.sleep(10)
|
time.sleep(12)
|
||||||
if first_time_init:
|
if first_time_init:
|
||||||
if not known_exception_caught:
|
if not known_exception_caught:
|
||||||
logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 '
|
logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 '
|
||||||
'https://github.com/RockChinQ/QChatGPT/issues/37')
|
'https://github.com/RockChinQ/QChatGPT/issues/37')
|
||||||
else:
|
else:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
else:
|
else:
|
||||||
|
@ -177,10 +190,15 @@ if __name__ == '__main__':
|
||||||
elif len(sys.argv) > 1 and sys.argv[1] == 'update':
|
elif len(sys.argv) > 1 and sys.argv[1] == 'update':
|
||||||
try:
|
try:
|
||||||
from dulwich import porcelain
|
from dulwich import porcelain
|
||||||
|
|
||||||
repo = porcelain.open_repo('.')
|
repo = porcelain.open_repo('.')
|
||||||
porcelain.pull(repo)
|
porcelain.pull(repo)
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
print("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
|
print("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
# import pkg.utils.configmgr
|
||||||
|
#
|
||||||
|
# pkg.utils.configmgr.set_config_and_reload("quote_origin", False)
|
||||||
|
|
||||||
main(True)
|
main(True)
|
||||||
|
|
|
@ -6,7 +6,6 @@ from sqlite3 import Cursor
|
||||||
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
|
||||||
import config
|
|
||||||
import pkg.utils.context
|
import pkg.utils.context
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,7 +24,6 @@ class DatabaseManager:
|
||||||
# 连接到数据库文件
|
# 连接到数据库文件
|
||||||
def reconnect(self):
|
def reconnect(self):
|
||||||
self.conn = sqlite3.connect('database.db', check_same_thread=False)
|
self.conn = sqlite3.connect('database.db', check_same_thread=False)
|
||||||
# self.conn.isolation_level = None
|
|
||||||
self.cursor = self.conn.cursor()
|
self.cursor = self.conn.cursor()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
@ -127,6 +125,7 @@ class DatabaseManager:
|
||||||
# 从数据库加载还没过期的session数据
|
# 从数据库加载还没过期的session数据
|
||||||
def load_valid_sessions(self) -> dict:
|
def load_valid_sessions(self) -> dict:
|
||||||
# 从数据库中加载所有还没过期的session
|
# 从数据库中加载所有还没过期的session
|
||||||
|
config = pkg.utils.context.get_config()
|
||||||
self.execute("""
|
self.execute("""
|
||||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
||||||
from `sessions` where `last_interact_timestamp` > {}
|
from `sessions` where `last_interact_timestamp` > {}
|
||||||
|
|
|
@ -5,7 +5,6 @@ import logging
|
||||||
import pkg.database.manager
|
import pkg.database.manager
|
||||||
import pkg.qqbot.manager
|
import pkg.qqbot.manager
|
||||||
import pkg.utils.context
|
import pkg.utils.context
|
||||||
import config
|
|
||||||
|
|
||||||
|
|
||||||
class KeysManager:
|
class KeysManager:
|
||||||
|
@ -34,6 +33,8 @@ class KeysManager:
|
||||||
def __init__(self, api_key):
|
def __init__(self, api_key):
|
||||||
# if hasattr(config, 'api_key_usage_threshold'):
|
# if hasattr(config, 'api_key_usage_threshold'):
|
||||||
# self.api_key_usage_threshold = config.api_key_usage_threshold
|
# self.api_key_usage_threshold = config.api_key_usage_threshold
|
||||||
|
|
||||||
|
config = pkg.utils.context.get_config()
|
||||||
if hasattr(config, 'api_key_fee_threshold'):
|
if hasattr(config, 'api_key_fee_threshold'):
|
||||||
self.api_key_fee_threshold = config.api_key_fee_threshold
|
self.api_key_fee_threshold = config.api_key_fee_threshold
|
||||||
self.load_fee()
|
self.load_fee()
|
||||||
|
@ -108,6 +109,7 @@ class KeysManager:
|
||||||
|
|
||||||
self.fee[md5] += fee
|
self.fee[md5] += fee
|
||||||
|
|
||||||
|
config = pkg.utils.context.get_config()
|
||||||
if self.fee[md5] >= self.api_key_fee_threshold and \
|
if self.fee[md5] >= self.api_key_fee_threshold and \
|
||||||
hasattr(config, 'auto_switch_api_key') and config.auto_switch_api_key:
|
hasattr(config, 'auto_switch_api_key') and config.auto_switch_api_key:
|
||||||
switch_result, key_name = self.auto_switch()
|
switch_result, key_name = self.auto_switch()
|
||||||
|
|
|
@ -2,8 +2,6 @@ import logging
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
import config
|
|
||||||
|
|
||||||
import pkg.openai.keymgr
|
import pkg.openai.keymgr
|
||||||
import pkg.openai.pricing as pricing
|
import pkg.openai.pricing as pricing
|
||||||
import pkg.utils.context
|
import pkg.utils.context
|
||||||
|
@ -34,6 +32,7 @@ class OpenAIInteract:
|
||||||
|
|
||||||
# 请求OpenAI Completion
|
# 请求OpenAI Completion
|
||||||
def request_completion(self, prompt, stop):
|
def request_completion(self, prompt, stop):
|
||||||
|
config = pkg.utils.context.get_config()
|
||||||
response = openai.Completion.create(
|
response = openai.Completion.create(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
|
@ -53,6 +52,7 @@ class OpenAIInteract:
|
||||||
|
|
||||||
def request_image(self, prompt):
|
def request_image(self, prompt):
|
||||||
|
|
||||||
|
config = pkg.utils.context.get_config()
|
||||||
params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params
|
params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params
|
||||||
|
|
||||||
response = openai.Image.create(
|
response = openai.Image.create(
|
||||||
|
|
|
@ -2,7 +2,6 @@ import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import config
|
|
||||||
import pkg.openai.manager
|
import pkg.openai.manager
|
||||||
import pkg.database.manager
|
import pkg.database.manager
|
||||||
import pkg.utils.context
|
import pkg.utils.context
|
||||||
|
@ -54,6 +53,7 @@ def dump_session(session_name: str):
|
||||||
|
|
||||||
# 从配置文件获取会话预设信息
|
# 从配置文件获取会话预设信息
|
||||||
def get_default_prompt():
|
def get_default_prompt():
|
||||||
|
import config
|
||||||
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
|
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
|
||||||
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
|
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
|
||||||
return user_name + ":{}\n".format(config.default_prompt if hasattr(config, 'default_prompt') \
|
return user_name + ":{}\n".format(config.default_prompt if hasattr(config, 'default_prompt') \
|
||||||
|
@ -85,6 +85,8 @@ class Session:
|
||||||
|
|
||||||
prompt = get_default_prompt()
|
prompt = get_default_prompt()
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
|
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
|
||||||
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
|
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
|
||||||
|
|
||||||
|
@ -130,6 +132,8 @@ class Session:
|
||||||
# 不是此session已更换,退出
|
# 不是此session已更换,退出
|
||||||
if self.create_timestamp != create_timestamp or self not in sessions.values():
|
if self.create_timestamp != create_timestamp or self not in sessions.values():
|
||||||
return
|
return
|
||||||
|
|
||||||
|
config = pkg.utils.context.get_config()
|
||||||
if int(time.time()) - self.last_interact_timestamp > config.session_expire_time:
|
if int(time.time()) - self.last_interact_timestamp > config.session_expire_time:
|
||||||
logging.info('session {} 已过期'.format(self.name))
|
logging.info('session {} 已过期'.format(self.name))
|
||||||
self.reset(expired=True, schedule_new=False)
|
self.reset(expired=True, schedule_new=False)
|
||||||
|
@ -144,6 +148,7 @@ class Session:
|
||||||
self.last_interact_timestamp = int(time.time())
|
self.last_interact_timestamp = int(time.time())
|
||||||
|
|
||||||
# max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7
|
# max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7
|
||||||
|
config = pkg.utils.context.get_config()
|
||||||
max_rounds = 1000 # 不再限制回合数
|
max_rounds = 1000 # 不再限制回合数
|
||||||
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
|
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,6 @@ import mirai.models.bus
|
||||||
from mirai import At, GroupMessage, MessageEvent, Mirai, Plain, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
|
from mirai import At, GroupMessage, MessageEvent, Mirai, Plain, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
|
||||||
FriendMessage, Image
|
FriendMessage, Image
|
||||||
|
|
||||||
import config
|
|
||||||
import pkg.openai.session
|
import pkg.openai.session
|
||||||
import pkg.openai.manager
|
import pkg.openai.manager
|
||||||
from func_timeout import FunctionTimedOut
|
from func_timeout import FunctionTimedOut
|
||||||
|
@ -26,6 +25,7 @@ def go(func, args=()):
|
||||||
|
|
||||||
# 检查消息是否符合泛响应匹配机制
|
# 检查消息是否符合泛响应匹配机制
|
||||||
def check_response_rule(text: str) -> (bool, str):
|
def check_response_rule(text: str) -> (bool, str):
|
||||||
|
config = pkg.utils.context.get_config()
|
||||||
if not hasattr(config, 'response_rules'):
|
if not hasattr(config, 'response_rules'):
|
||||||
return False, ''
|
return False, ''
|
||||||
|
|
||||||
|
@ -60,6 +60,7 @@ class QQBotManager:
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.retry = retry
|
self.retry = retry
|
||||||
|
|
||||||
|
config = pkg.utils.context.get_config()
|
||||||
if os.path.exists("sensitive.json") \
|
if os.path.exists("sensitive.json") \
|
||||||
and config.sensitive_word_filter is not None \
|
and config.sensitive_word_filter is not None \
|
||||||
and config.sensitive_word_filter:
|
and config.sensitive_word_filter:
|
||||||
|
@ -134,6 +135,7 @@ class QQBotManager:
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
|
||||||
def send(self, event, msg, check_quote=True):
|
def send(self, event, msg, check_quote=True):
|
||||||
|
config = pkg.utils.context.get_config()
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
self.bot.send(event, msg, quote=True if hasattr(config,
|
self.bot.send(event, msg, quote=True if hasattr(config,
|
||||||
"quote_origin") and config.quote_origin and check_quote else False))
|
"quote_origin") and config.quote_origin and check_quote else False))
|
||||||
|
@ -216,6 +218,7 @@ class QQBotManager:
|
||||||
|
|
||||||
# 通知系统管理员
|
# 通知系统管理员
|
||||||
def notify_admin(self, message: str):
|
def notify_admin(self, message: str):
|
||||||
|
config = pkg.utils.context.get_config()
|
||||||
if hasattr(config, "admin_qq") and config.admin_qq != 0:
|
if hasattr(config, "admin_qq") and config.admin_qq != 0:
|
||||||
logging.info("通知管理员:{}".format(message))
|
logging.info("通知管理员:{}".format(message))
|
||||||
send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message))
|
send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message))
|
||||||
|
|
|
@ -9,6 +9,9 @@ import openai
|
||||||
|
|
||||||
from mirai import Image, MessageChain
|
from mirai import Image, MessageChain
|
||||||
|
|
||||||
|
# 这里不使用动态引入config
|
||||||
|
# 因为在这里动态引入会卡死程序
|
||||||
|
# 而此模块静态引用config与动态引入的表现一致
|
||||||
import config
|
import config
|
||||||
|
|
||||||
import pkg.openai.session
|
import pkg.openai.session
|
||||||
|
|
|
@ -9,9 +9,18 @@ context = {
|
||||||
'qqbot.manager.QQBotManager': None,
|
'qqbot.manager.QQBotManager': None,
|
||||||
},
|
},
|
||||||
'logger_handler': None,
|
'logger_handler': None,
|
||||||
|
'config': None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def set_config(inst):
|
||||||
|
context['config'] = inst
|
||||||
|
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
return context['config']
|
||||||
|
|
||||||
|
|
||||||
def set_database_manager(inst):
|
def set_database_manager(inst):
|
||||||
context['inst']['database.manager.DatabaseManager'] = inst
|
context['inst']['database.manager.DatabaseManager'] = inst
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user