doc: 添加代码注释

This commit is contained in:
Rock Chin 2022-12-11 16:10:12 +08:00
parent deffae25da
commit 76b60a781f
5 changed files with 27 additions and 6 deletions

View File

@ -21,6 +21,9 @@ openai_config = {
"api_key": "", "api_key": "",
} }
# 管理员QQ号用于接收报错等通知为0时不发送通知
admin_qq = 0
# OpenAI的completion API的参数 # OpenAI的completion API的参数
# 不了解的话请不要修改具体请查看OpenAI的文档 # 不了解的话请不要修改具体请查看OpenAI的文档
completion_api_params = { completion_api_params = {
@ -55,9 +58,6 @@ session_expire_time = 60 * 20
# 日志级别 # 日志级别
logging_level = logging.INFO logging_level = logging.INFO
# 管理员QQ号用于接收报错等通知
admin_qq = 0
# 定制帮助消息 # 定制帮助消息
help_message = """此机器人通过调用OpenAI的GPT-3大型语言模型生成回复不具有情感。 help_message = """此机器人通过调用OpenAI的GPT-3大型语言模型生成回复不具有情感。
你可以用自然语言与其交流回复的消息中[GPT]开头的为模型生成的语言[bot]开头的为程序提示 你可以用自然语言与其交流回复的消息中[GPT]开头的为模型生成的语言[bot]开头的为程序提示

View File

@ -8,7 +8,8 @@ import config
inst = None inst = None
# 数据库管理
# 为其他模块提供数据库操作接口
class DatabaseManager: class DatabaseManager:
conn = None conn = None
cursor = None cursor = None
@ -17,15 +18,16 @@ class DatabaseManager:
self.reconnect() self.reconnect()
global inst global inst
inst = self inst = self
# 连接到数据库文件
def reconnect(self): def reconnect(self):
self.conn = sqlite3.connect('database.db', check_same_thread=False) self.conn = sqlite3.connect('database.db', check_same_thread=False)
# self.conn.isolation_level = None # self.conn.isolation_level = None
self.cursor = self.conn.cursor() self.cursor = self.conn.cursor()
# 初始化数据库的函数
def initialize_database(self): def initialize_database(self):
self.cursor.execute(""" self.cursor.execute("""
create table if not exists `sessions` ( create table if not exists `sessions` (
@ -42,6 +44,7 @@ class DatabaseManager:
self.conn.commit() self.conn.commit()
print('Database initialized.') print('Database initialized.')
# session持久化
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
last_interact_timestamp: int, prompt: str): last_interact_timestamp: int, prompt: str):
# 检查是否已经有了此name和create_timestamp的session # 检查是否已经有了此name和create_timestamp的session
@ -66,6 +69,7 @@ class DatabaseManager:
subject_number, create_timestamp)) subject_number, create_timestamp))
self.conn.commit() self.conn.commit()
# 显式关闭一个session
def explicit_close_session(self, session_name: str, create_timestamp: int): def explicit_close_session(self, session_name: str, create_timestamp: int):
self.cursor.execute(""" self.cursor.execute("""
update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {} update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {}
@ -78,13 +82,14 @@ class DatabaseManager:
""".format(session_name, create_timestamp)) """.format(session_name, create_timestamp))
self.conn.commit() self.conn.commit()
# 设置session为过期
def set_session_expired(self, session_name: str, create_timestamp: int): def set_session_expired(self, session_name: str, create_timestamp: int):
self.cursor.execute(""" self.cursor.execute("""
update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {} update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp)) """.format(session_name, create_timestamp))
self.conn.commit() self.conn.commit()
# 载还没过期的session数据 # 从数据库加载还没过期的session数据
def load_valid_sessions(self) -> dict: def load_valid_sessions(self) -> dict:
# 从数据库中加载所有还没过期的session # 从数据库中加载所有还没过期的session
self.cursor.execute(""" self.cursor.execute("""
@ -175,6 +180,7 @@ class DatabaseManager:
'prompt': prompt 'prompt': prompt
} }
# 列出与某个对象的所有对话session
def list_history(self, session_name: str, capacity: int, page: int): def list_history(self, session_name: str, capacity: int, page: int):
self.cursor.execute(""" self.cursor.execute("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`

View File

@ -3,6 +3,7 @@ import openai
inst = None inst = None
# 为其他模块提供与OpenAI交互的接口
class OpenAIInteract: class OpenAIInteract:
api_key = '' api_key = ''
api_params = {} api_params = {}
@ -16,6 +17,7 @@ class OpenAIInteract:
global inst global inst
inst = self inst = self
# 请求OpenAI Completion
def request_completion(self, prompt, stop): def request_completion(self, prompt, stop):
response = openai.Completion.create( response = openai.Completion.create(
prompt=prompt, prompt=prompt,

View File

@ -6,6 +6,7 @@ import config
import pkg.openai.manager import pkg.openai.manager
import pkg.database.manager import pkg.database.manager
# 运行时保存的所有session
sessions = {} sessions = {}
@ -14,6 +15,7 @@ class SessionOfflineStatus:
EXPLICITLY_CLOSED = 'explicitly_closed' EXPLICITLY_CLOSED = 'explicitly_closed'
# 从数据加载session
def load_sessions(): def load_sessions():
global sessions global sessions
@ -33,6 +35,7 @@ def load_sessions():
sessions[session_name] = temp_session sessions[session_name] = temp_session
# 获取指定名称的session如果不存在则创建一个新的
def get_session(session_name: str): def get_session(session_name: str):
global sessions global sessions
if session_name not in sessions: if session_name not in sessions:
@ -49,6 +52,8 @@ def dump_session(session_name: str):
# 通用的OpenAI API交互session # 通用的OpenAI API交互session
# session内部保留了对话的上下文
# 收到用户消息后将上下文提交给OpenAI API生成回复
class Session: class Session:
name = '' name = ''
@ -69,6 +74,7 @@ class Session:
self.last_interact_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time())
self.schedule() self.schedule()
# 设定检查session最后一次对话是否超过过期时间的计时器
def schedule(self): def schedule(self):
threading.Thread(target=self.expire_check_timer_loop, args=(self.create_timestamp,)).start() threading.Thread(target=self.expire_check_timer_loop, args=(self.create_timestamp,)).start()
@ -146,6 +152,7 @@ class Session:
logging.debug('cut_out: {}'.format(result)) logging.debug('cut_out: {}'.format(result))
return result return result
# 持久化session
def persistence(self): def persistence(self):
if self.prompt == '': if self.prompt == '':
return return
@ -160,6 +167,7 @@ class Session:
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
self.prompt) self.prompt)
# 重置session
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True): def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True):
if self.prompt != '': if self.prompt != '':
self.persistence() self.persistence()
@ -195,6 +203,7 @@ class Session:
just_switched = True just_switched = True
return self return self
# 切换到下一个session
def next_session(self): def next_session(self):
next_one = pkg.database.manager.get_inst().next_session(self.name, self.last_interact_timestamp) next_one = pkg.database.manager.get_inst().next_session(self.name, self.last_interact_timestamp)
if next_one is None: if next_one is None:

View File

@ -17,6 +17,7 @@ inst = None
processing = [] processing = []
# 控制QQ消息输入输出的类
class QQBotManager: class QQBotManager:
timeout = 60 timeout = 60
retry = 3 retry = 3
@ -159,6 +160,7 @@ class QQBotManager:
return reply return reply
# 私聊消息处理
async def on_person_message(self, event: MessageEvent): async def on_person_message(self, event: MessageEvent):
global processing global processing
if "person_{}".format(event.sender.id) in processing: if "person_{}".format(event.sender.id) in processing:
@ -192,6 +194,7 @@ class QQBotManager:
if reply != '': if reply != '':
return await self.bot.send(event, reply) return await self.bot.send(event, reply)
# 群消息处理
async def on_group_message(self, event: GroupMessage): async def on_group_message(self, event: GroupMessage):
global processing global processing
if "group_{}".format(event.group.id) in processing: if "group_{}".format(event.group.id) in processing:
@ -226,6 +229,7 @@ class QQBotManager:
if reply != '': if reply != '':
return await self.bot.send(event, reply) return await self.bot.send(event, reply)
# 通知系统管理员
def notify_admin(self, message: str): def notify_admin(self, message: str):
if config.admin_qq is not None and config.admin_qq != 0: if config.admin_qq is not None and config.admin_qq != 0:
send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message)) send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message))