QChatGPT/pkg/openai/session.py
2022-12-08 13:22:54 +08:00

111 lines
3.0 KiB
Python

import time
import pkg.openai.manager
import pkg.database.manager
sessions = {}
class SessionOfflineStatus:
ON_GOING = 'on_going'
EXPLICITLY_CLOSED = 'explicitly_closed'
def load_sessions():
global sessions
db_inst = pkg.database.manager.get_inst()
session_data = db_inst.load_valid_sessions()
for session_name in session_data:
temp_session = Session(session_name)
temp_session.name = session_name
temp_session.create_timestamp = session_data[session_name]['create_timestamp']
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
temp_session.prompt = session_data[session_name]['prompt']
sessions[session_name] = temp_session
def get_session(session_name: str):
global sessions
if session_name not in sessions:
sessions[session_name] = Session(session_name)
return sessions[session_name]
def dump_session(session_name: str):
global sessions
if session_name in sessions:
assert isinstance(sessions[session_name], Session)
sessions[session_name].persistence()
del sessions[session_name]
# 通用的OpenAI API交互session
class Session:
name = ''
prompt = ''
user_name = 'You'
bot_name = 'Bot'
create_timestamp = 0
last_interact_timestamp = 0
def __init__(self, name: str):
self.name = name
self.create_timestamp = int(time.time())
# 请求回复
# 这个函数是阻塞的
def append(self, text: str) -> str:
self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':'
self.last_interact_timestamp = int(time.time())
# 向API请求补全
response = pkg.openai.manager.get_inst().request_completion(self.prompt, self.user_name + ':')
# print(response)
# 处理回复
res_test = response["choices"][0]["text"]
res_ans = res_test
# 去除开头可能的提示
res_ans_spt = res_test.split("\n\n")
if len(res_ans_spt) > 1:
del (res_ans_spt[0])
res_ans = '\n\n'.join(res_ans_spt)
self.prompt += "{}".format(res_ans) + '\n'
return res_ans
def persistence(self):
db_inst = pkg.database.manager.get_inst()
name_spt = self.name.split('_')
subject_type = name_spt[0]
subject_number = int(name_spt[1])
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
self.prompt)
def reset(self, explicit: bool = False):
if self.prompt != '':
self.persistence()
if explicit:
pkg.database.manager.get_inst().explicit_close_session(self.name, self.create_timestamp)
self.prompt = ''
self.create_timestamp = int(time.time())
self.last_interact_timestamp = 0
def last_session(self):
pass
def next_session(self):
pass