diff --git a/pkg/openai~/__init__.py b/pkg/openai~/__init__.py deleted file mode 100644 index e6a669c..0000000 --- a/pkg/openai~/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""OpenAI 接口处理及会话管理相关 -""" diff --git a/pkg/openai~/dprompt.py b/pkg/openai~/dprompt.py deleted file mode 100644 index 3aba31c..0000000 --- a/pkg/openai~/dprompt.py +++ /dev/null @@ -1,79 +0,0 @@ -# 多情景预设值管理 - -__current__ = "default" -"""当前默认使用的情景预设的名称 - -由管理员使用`!default <名称>`指令切换 -""" - -__prompts_from_files__ = {} -"""从文件中读取的情景预设值""" - - -def read_prompt_from_file() -> str: - """从文件读取预设值""" - # 读取prompts/目录下的所有文件,以文件名为键,文件内容为值 - # 保存在__prompts_from_files__中 - global __prompts_from_files__ - import os - - __prompts_from_files__ = {} - for file in os.listdir("prompts"): - with open(os.path.join("prompts", file), encoding="utf-8") as f: - __prompts_from_files__[file] = f.read() - - -def get_prompt_dict() -> dict: - """获取预设值字典""" - import config - default_prompt = config.default_prompt - if type(default_prompt) == str: - default_prompt = {"default": default_prompt} - elif type(default_prompt) == dict: - pass - else: - raise TypeError("default_prompt must be str or dict") - - # 将文件中的预设值合并到default_prompt中 - for key in __prompts_from_files__: - default_prompt[key] = __prompts_from_files__[key] - - return default_prompt - - -def set_current(name): - global __current__ - for key in get_prompt_dict(): - if key.lower().startswith(name.lower()): - __current__ = key - return - raise KeyError("未找到情景预设: " + name) - - -def get_current(): - global __current__ - return __current__ - - -def set_to_default(): - global __current__ - default_dict = get_prompt_dict() - - if "default" in default_dict: - __current__ = "default" - else: - __current__ = list(default_dict.keys())[0] - - -def get_prompt(name: str = None) -> str: - """获取预设值""" - if name is None: - name = get_current() - - default_dict = get_prompt_dict() - - for key in default_dict: - if key.lower().startswith(name.lower()): - return default_dict[key] - - raise KeyError("未找到情景预设: " + name) diff --git a/pkg/openai~/keymgr.py b/pkg/openai~/keymgr.py deleted file mode 100644 index 7127db8..0000000 --- a/pkg/openai~/keymgr.py +++ /dev/null @@ -1,91 +0,0 @@ -# 此模块提供了维护api-key的各种功能 -import hashlib -import logging - -import pkg.plugin.host as plugin_host -import pkg.plugin.models as plugin_models - - -class KeysManager: - api_key = {} - """所有api-key""" - - using_key = "" - """当前使用的api-key - """ - - alerted = [] - """已提示过超额的key - - 记录在此以避免重复提示 - """ - - exceeded = [] - """已超额的key - - 供自动切换功能识别 - """ - - def get_using_key(self): - return self.using_key - - def get_using_key_md5(self): - return hashlib.md5(self.using_key.encode('utf-8')).hexdigest() - - def __init__(self, api_key): - - if type(api_key) is dict: - self.api_key = api_key - elif type(api_key) is str: - self.api_key = { - "default": api_key - } - elif type(api_key) is list: - for i in range(len(api_key)): - self.api_key[str(i)] = api_key[i] - # 从usage中删除未加载的api-key的记录 - # 不删了,也许会运行时添加曾经有记录的api-key - - self.auto_switch() - - def auto_switch(self) -> (bool, str): - """尝试切换api-key - - Returns: - 是否切换成功, 切换后的api-key的别名 - """ - - for key_name in self.api_key: - if self.api_key[key_name] not in self.exceeded: - self.using_key = self.api_key[key_name] - - logging.info("使用api-key:" + key_name) - - # 触发插件事件 - args = { - "key_name": key_name, - "key_list": self.api_key.keys() - } - _ = plugin_host.emit(plugin_models.KeySwitched, **args) - - return True, key_name - - self.using_key = list(self.api_key.values())[0] - logging.info("使用api-key:" + list(self.api_key.keys())[0]) - - return False, "" - - def add(self, key_name, key): - self.api_key[key_name] = key - - def set_current_exceeded(self): - """设置当前使用的api-key使用量超限 - """ - self.exceeded.append(self.using_key) - - def get_key_name(self, api_key): - """根据api-key获取其别名""" - for key_name in self.api_key: - if self.api_key[key_name] == api_key: - return key_name - return "" \ No newline at end of file diff --git a/pkg/openai~/manager.py b/pkg/openai~/manager.py deleted file mode 100644 index 4a3ceab..0000000 --- a/pkg/openai~/manager.py +++ /dev/null @@ -1,93 +0,0 @@ -import logging - -import openai - -import pkg.openai.keymgr -import pkg.utils.context -import pkg.audit.gatherer -from pkg.openai.modelmgr import ModelRequest, create_openai_model_request - - -class OpenAIInteract: - """OpenAI 接口封装 - - 将文字接口和图片接口封装供调用方使用 - """ - - key_mgr: pkg.openai.keymgr.KeysManager = None - - audit_mgr: pkg.audit.gatherer.DataGatherer = None - - default_image_api_params = { - "size": "256x256", - } - - def __init__(self, api_key: str): - - self.key_mgr = pkg.openai.keymgr.KeysManager(api_key) - self.audit_mgr = pkg.audit.gatherer.DataGatherer() - - logging.info("文字总使用量:%d", self.audit_mgr.get_total_text_length()) - - openai.api_key = self.key_mgr.get_using_key() - - pkg.utils.context.set_openai_manager(self) - - # 请求OpenAI Completion - def request_completion(self, prompts) -> str: - """请求补全接口回复 - - Parameters: - prompts (str): 提示语 - - Returns: - str: 回复 - """ - - config = pkg.utils.context.get_config() - - # 根据模型选择使用的接口 - ai: ModelRequest = create_openai_model_request( - config.completion_api_params['model'], - 'user', - config.openai_config["http_proxy"] if "http_proxy" in config.openai_config else None - ) - ai.request( - prompts, - **config.completion_api_params - ) - response = ai.get_response() - - logging.debug("OpenAI response: %s", response) - - if 'model' in config.completion_api_params: - self.audit_mgr.report_text_model_usage(config.completion_api_params['model'], - ai.get_total_tokens()) - elif 'engine' in config.completion_api_params: - self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'], - response['usage']['total_tokens']) - - return ai.get_message() - - def request_image(self, prompt) -> dict: - """请求图片接口回复 - - Parameters: - prompt (str): 提示语 - - Returns: - dict: 响应 - """ - config = pkg.utils.context.get_config() - params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params - - response = openai.Image.create( - prompt=prompt, - n=1, - **params - ) - - self.audit_mgr.report_image_model_usage(params['size']) - - return response - diff --git a/pkg/openai~/modelmgr.py b/pkg/openai~/modelmgr.py deleted file mode 100644 index e67f98c..0000000 --- a/pkg/openai~/modelmgr.py +++ /dev/null @@ -1,184 +0,0 @@ -"""OpenAI 接口底层封装 - -目前使用的对话接口有: -ChatCompletion - gpt-3.5-turbo 等模型 -Completion - text-davinci-003 等模型 -此模块封装此两个接口的请求实现,为上层提供统一的调用方式 -""" -import openai, logging, threading, asyncio -import openai.error as aiE - -COMPLETION_MODELS = { - 'text-davinci-003', - 'text-davinci-002', - 'code-davinci-002', - 'code-cushman-001', - 'text-curie-001', - 'text-babbage-001', - 'text-ada-001', -} - -CHAT_COMPLETION_MODELS = { - 'gpt-3.5-turbo', - 'gpt-3.5-turbo-0301', -} - -EDIT_MODELS = { - -} - -IMAGE_MODELS = { - -} - - -class ModelRequest: - """模型接口请求父类""" - - can_chat = False - runtime: threading.Thread = None - ret = {} - proxy: str = None - request_ready = True - error_info: str = "若在没有任何错误的情况下看到这句话,请带着配置文件上报Issues" - - def __init__(self, model_name, user_name, request_fun, http_proxy:str = None, time_out = None): - self.model_name = model_name - self.user_name = user_name - self.request_fun = request_fun - self.time_out = time_out - if http_proxy != None: - self.proxy = http_proxy - openai.proxy = self.proxy - self.request_ready = False - - async def __a_request__(self, **kwargs): - """异步请求""" - - try: - self.ret:dict = await self.request_fun(**kwargs) - self.request_ready = True - except aiE.APIConnectionError as e: - self.error_info = "{}\n请检查网络连接或代理是否正常".format(e) - raise ConnectionError(self.error_info) - except ValueError as e: - self.error_info = "{}\n该错误可能是由于http_proxy格式设置错误引起的" - except Exception as e: - self.error_info = "{}\n由于请求异常产生的未知错误,请查看日志".format(e) - raise Exception(self.error_info) - - def request(self, **kwargs): - """向接口发起请求""" - - if self.proxy != None: #异步请求 - self.request_ready = False - loop = asyncio.new_event_loop() - self.runtime = threading.Thread( - target=loop.run_until_complete, - args=(self.__a_request__(**kwargs),) - ) - self.runtime.start() - else: #同步请求 - self.ret = self.request_fun(**kwargs) - - def __msg_handle__(self, msg): - """将prompt dict转换成接口需要的格式""" - return msg - - def ret_handle(self): - ''' - API消息返回处理函数 - 若重写该方法,应检查异步线程状态,或在需要检查处super该方法 - ''' - if self.runtime != None and isinstance(self.runtime, threading.Thread): - self.runtime.join(self.time_out) - if self.request_ready: - return - raise Exception(self.error_info) - - def get_total_tokens(self): - try: - return self.ret['usage']['total_tokens'] - except: - return 0 - - def get_message(self): - return self.message - - def get_response(self): - return self.ret - - -class ChatCompletionModel(ModelRequest): - """ChatCompletion接口的请求实现""" - - Chat_role = ['system', 'user', 'assistant'] - def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs): - if http_proxy == None: - request_fun = openai.ChatCompletion.create - else: - request_fun = openai.ChatCompletion.acreate - self.can_chat = True - super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs) - - def request(self, prompts, **kwargs): - prompts = self.__msg_handle__(prompts) - kwargs['messages'] = prompts - super().request(**kwargs) - self.ret_handle() - - def __msg_handle__(self, msgs): - temp_msgs = [] - # 把msgs拷贝进temp_msgs - for msg in msgs: - temp_msgs.append(msg.copy()) - return temp_msgs - - def get_message(self): - return self.ret["choices"][0]["message"]['content'] #需要时直接加载加快请求速度,降低内存消耗 - - -class CompletionModel(ModelRequest): - """Completion接口的请求实现""" - - def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs): - if http_proxy == None: - request_fun = openai.Completion.create - else: - request_fun = openai.Completion.acreate - super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs) - - def request(self, prompts, **kwargs): - prompts = self.__msg_handle__(prompts) - kwargs['prompt'] = prompts - super().request(**kwargs) - self.ret_handle() - - def __msg_handle__(self, msgs): - prompt = '' - for msg in msgs: - prompt = prompt + "{}: {}\n".format(msg['role'], msg['content']) - # for msg in msgs: - # if msg['role'] == 'assistant': - # prompt = prompt + "{}\n".format(msg['content']) - # else: - # prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content']) - prompt = prompt + "assistant: " - return prompt - - def get_message(self): - return self.ret["choices"][0]["text"] - - -def create_openai_model_request(model_name: str, user_name: str = 'user', http_proxy:str = None) -> ModelRequest: - """使用给定的模型名称创建模型请求对象""" - if model_name in CHAT_COMPLETION_MODELS: - model = ChatCompletionModel(model_name, user_name, http_proxy) - elif model_name in COMPLETION_MODELS: - model = CompletionModel(model_name, user_name, http_proxy) - else : - log = "找不到模型[{}],请检查配置文件".format(model_name) - logging.error(log) - raise IndexError(log) - logging.debug("使用接口[{}]创建模型请求[{}]".format(model.__class__.__name__, model_name)) - return model diff --git a/pkg/openai~/pricing.bak.py b/pkg/openai~/pricing.bak.py deleted file mode 100644 index 8a46978..0000000 --- a/pkg/openai~/pricing.bak.py +++ /dev/null @@ -1,28 +0,0 @@ -# 计费模块 -# 已弃用 https://github.com/RockChinQ/QChatGPT/issues/81 - -import logging - -pricing = { - "base": { # 文字模型单位是1000字符 - "text-davinci-003": 0.02, - }, - "image": { - "256x256": 0.016, - "512x512": 0.018, - "1024x1024": 0.02, - } -} - - -def language_base_price(model, text): - salt_rate = 0.93 - length = ((len(text.encode('utf-8')) - len(text)) / 2 + len(text)) * salt_rate - logging.debug("text length: %d" % length) - - return pricing["base"][model] * length / 1000 - - -def image_price(size): - logging.debug("image size: %s" % size) - return pricing["image"][size] diff --git a/pkg/openai~/session.py b/pkg/openai~/session.py deleted file mode 100644 index 38a629d..0000000 --- a/pkg/openai~/session.py +++ /dev/null @@ -1,370 +0,0 @@ -"""主线使用的会话管理模块 - -每个人、每个群单独一个session,session内部保留了对话的上下文, -""" - -import logging -import threading -import time -import json - -import pkg.openai.manager -import pkg.openai.modelmgr -import pkg.database.manager -import pkg.utils.context - -import pkg.plugin.host as plugin_host -import pkg.plugin.models as plugin_models - -# 运行时保存的所有session -sessions = {} - - -class SessionOfflineStatus: - ON_GOING = 'on_going' - EXPLICITLY_CLOSED = 'explicitly_closed' - - -# 重置session.prompt -def reset_session_prompt(session_name, prompt): - # 备份原始数据 - bak_path = 'logs/{}-{}.bak'.format( - session_name, - time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) - ) - f = open(bak_path, 'w+') - f.write(prompt) - f.close() - # 生成新数据 - config = pkg.utils.context.get_config() - prompt = [ - { - 'role': 'system', - 'content': config.default_prompt['default'] - } - ] - # 警告 - logging.warning( - """ -用户[{}]的数据已被重置,有可能是因为数据版本过旧或存储错误 -原始数据将备份在: -{}""".format(session_name, bak_path) - ) # 为保证多行文本格式正确故无缩进 - return prompt - - -# 从数据加载session -def load_sessions(): - """从数据库加载sessions""" - - global sessions - - db_inst = pkg.utils.context.get_database_manager() - - session_data = db_inst.load_valid_sessions() - - for session_name in session_data: - logging.info('加载session: {}'.format(session_name)) - - temp_session = Session(session_name) - temp_session.name = session_name - temp_session.create_timestamp = session_data[session_name]['create_timestamp'] - temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp'] - try: - temp_session.prompt = json.loads(session_data[session_name]['prompt']) - except Exception: - temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt']) - temp_session.persistence() - - sessions[session_name] = temp_session - - -# 获取指定名称的session,如果不存在则创建一个新的 -def get_session(session_name: str): - global sessions - if session_name not in sessions: - sessions[session_name] = Session(session_name) - return sessions[session_name] - - -def dump_session(session_name: str): - global sessions - if session_name in sessions: - assert isinstance(sessions[session_name], Session) - sessions[session_name].persistence() - del sessions[session_name] - - -# 通用的OpenAI API交互session -# session内部保留了对话的上下文, -# 收到用户消息后,将上下文提交给OpenAI API生成回复 -class Session: - name = '' - - prompt = [] - """使用list来保存会话中的回合""" - - create_timestamp = 0 - """会话创建时间""" - - last_interact_timestamp = 0 - """上次交互(产生回复)时间""" - - just_switched_to_exist_session = False - - response_lock = None - - # 加锁 - def acquire_response_lock(self): - logging.debug('{},lock acquire,{}'.format(self.name, self.response_lock)) - self.response_lock.acquire() - logging.debug('{},lock acquire successfully,{}'.format(self.name, self.response_lock)) - - # 释放锁 - def release_response_lock(self): - if self.response_lock.locked(): - logging.debug('{},lock release,{}'.format(self.name, self.response_lock)) - self.response_lock.release() - logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock)) - - # 从配置文件获取会话预设信息 - def get_default_prompt(self, use_default: str = None): - config = pkg.utils.context.get_config() - - import pkg.openai.dprompt as dprompt - - if use_default is None: - current_default_prompt = dprompt.get_prompt(dprompt.get_current()) - else: - current_default_prompt = dprompt.get_prompt(use_default) - - return [ - { - 'role': 'user', - 'content': current_default_prompt - }, { - 'role': 'assistant', - 'content': 'ok' - } - ] - - def __init__(self, name: str): - self.name = name - self.create_timestamp = int(time.time()) - self.last_interact_timestamp = int(time.time()) - self.schedule() - - self.response_lock = threading.Lock() - self.prompt = self.get_default_prompt() - - # 设定检查session最后一次对话是否超过过期时间的计时器 - def schedule(self): - threading.Thread(target=self.expire_check_timer_loop, args=(self.create_timestamp,)).start() - - # 检查session是否已经过期 - def expire_check_timer_loop(self, create_timestamp: int): - global sessions - while True: - time.sleep(60) - - # 不是此session已更换,退出 - if self.create_timestamp != create_timestamp or self not in sessions.values(): - return - - config = pkg.utils.context.get_config() - if int(time.time()) - self.last_interact_timestamp > config.session_expire_time: - logging.info('session {} 已过期'.format(self.name)) - - # 触发插件事件 - args = { - 'session_name': self.name, - 'session': self, - 'session_expire_time': config.session_expire_time - } - event = pkg.plugin.host.emit(plugin_models.SessionExpired, **args) - if event.is_prevented_default(): - return - - self.reset(expired=True, schedule_new=False) - - # 删除此session - del sessions[self.name] - return - - # 请求回复 - # 这个函数是阻塞的 - def append(self, text: str) -> str: - """向session中添加一条消息,返回接口回复""" - - self.last_interact_timestamp = int(time.time()) - - # 触发插件事件 - if self.prompt == self.get_default_prompt(): - args = { - 'session_name': self.name, - 'session': self, - 'default_prompt': self.prompt, - } - - event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args) - if event.is_prevented_default(): - return None - - config = pkg.utils.context.get_config() - max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024 - - # 向API请求补全 - message = pkg.utils.context.get_openai_manager().request_completion( - self.cut_out(text, max_length), - ) - - # 成功获取,处理回复 - res_test = message - res_ans = res_test - - # 去除开头可能的提示 - res_ans_spt = res_test.split("\n\n") - if len(res_ans_spt) > 1: - del (res_ans_spt[0]) - res_ans = '\n\n'.join(res_ans_spt) - - # 将此次对话的双方内容加入到prompt中 - self.prompt.append({'role': 'user', 'content': text}) - self.prompt.append({'role': 'assistant', 'content': res_ans}) - - if self.just_switched_to_exist_session: - self.just_switched_to_exist_session = False - self.set_ongoing() - - return res_ans if res_ans[0] != '\n' else res_ans[1:] - - # 删除上一回合并返回上一回合的问题 - def undo(self) -> str: - self.last_interact_timestamp = int(time.time()) - - # 删除最后两个消息 - if len(self.prompt) < 2: - raise Exception('之前无对话,无法撤销') - - question = self.prompt[-2]['content'] - self.prompt = self.prompt[:-2] - - # 返回上一回合的问题 - return question - - # 构建对话体 - def cut_out(self, msg: str, max_tokens: int) -> list: - """将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens""" - # 如果用户消息长度超过max_tokens,直接返回 - - temp_prompt = [ - { - 'role': 'user', - 'content': msg - } - ] - - token_count = len(msg) - # 倒序遍历prompt - for i in range(len(self.prompt) - 1, -1, -1): - if token_count >= max_tokens: - break - - # 将prompt加到temp_prompt头部 - temp_prompt.insert(0, self.prompt[i]) - token_count += len(self.prompt[i]['content']) - - logging.debug('cut_out: {}'.format(str(temp_prompt))) - - return temp_prompt - - # 持久化session - def persistence(self): - if self.prompt == self.get_default_prompt(): - return - - db_inst = pkg.utils.context.get_database_manager() - - name_spt = self.name.split('_') - - subject_type = name_spt[0] - subject_number = int(name_spt[1]) - - db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, - json.dumps(self.prompt)) - - # 重置session - def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None): - if self.prompt[-1]['role'] != "system": - self.persistence() - if explicit: - # 触发插件事件 - args = { - 'session_name': self.name, - 'session': self - } - - # 此事件不支持阻止默认行为 - _ = pkg.plugin.host.emit(plugin_models.SessionExplicitReset, **args) - - pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp) - - if expired: - pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp) - self.prompt = self.get_default_prompt(use_prompt) - self.create_timestamp = int(time.time()) - self.last_interact_timestamp = int(time.time()) - self.just_switched_to_exist_session = False - - # self.response_lock = threading.Lock() - - if schedule_new: - self.schedule() - - # 将本session的数据库状态设置为on_going - def set_ongoing(self): - pkg.utils.context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp) - - # 切换到上一个session - def last_session(self): - last_one = pkg.utils.context.get_database_manager().last_session(self.name, self.last_interact_timestamp) - if last_one is None: - return None - else: - self.persistence() - - self.create_timestamp = last_one['create_timestamp'] - self.last_interact_timestamp = last_one['last_interact_timestamp'] - try: - self.prompt = json.loads(last_one['prompt']) - except json.decoder.JSONDecodeError: - self.prompt = reset_session_prompt(self.name, last_one['prompt']) - self.persistence() - - self.just_switched_to_exist_session = True - return self - - # 切换到下一个session - def next_session(self): - next_one = pkg.utils.context.get_database_manager().next_session(self.name, self.last_interact_timestamp) - if next_one is None: - return None - else: - self.persistence() - - self.create_timestamp = next_one['create_timestamp'] - self.last_interact_timestamp = next_one['last_interact_timestamp'] - try: - self.prompt = json.loads(next_one['prompt']) - except json.decoder.JSONDecodeError: - self.prompt = reset_session_prompt(self.name, next_one['prompt']) - self.persistence() - - self.just_switched_to_exist_session = True - return self - - def list_history(self, capacity: int = 10, page: int = 0): - return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page) - - def draw_image(self, prompt: str): - return pkg.utils.context.get_openai_manager().request_image(prompt) diff --git a/pkg/qqbot~/__init__.py b/pkg/qqbot~/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/pkg/qqbot~/banlist.py b/pkg/qqbot~/banlist.py deleted file mode 100644 index 2c7dcb1..0000000 --- a/pkg/qqbot~/banlist.py +++ /dev/null @@ -1,50 +0,0 @@ -import pkg.utils.context - - -def is_banned(launcher_type: str, launcher_id: int, sender_id: int) -> bool: - if not pkg.utils.context.get_qqbot_manager().enable_banlist: - return False - - result = False - - if launcher_type == 'group': - # 检查是否显式声明发起人QQ要被person忽略 - if sender_id in pkg.utils.context.get_qqbot_manager().ban_person: - result = True - else: - for group_rule in pkg.utils.context.get_qqbot_manager().ban_group: - if type(group_rule) == int: - if group_rule == launcher_id: # 此群群号被禁用 - result = True - elif type(group_rule) == str: - if group_rule.startswith('!'): - # 截取!后面的字符串作为表达式,判断是否匹配 - reg_str = group_rule[1:] - import re - if re.match(reg_str, str(launcher_id)): # 被豁免,最高级别 - result = False - break - else: - # 判断是否匹配regexp - import re - if re.match(group_rule, str(launcher_id)): # 此群群号被禁用 - result = True - - else: - # ban_person, 与群规则相同 - for person_rule in pkg.utils.context.get_qqbot_manager().ban_person: - if type(person_rule) == int: - if person_rule == launcher_id: - result = True - elif type(person_rule) == str: - if person_rule.startswith('!'): - reg_str = person_rule[1:] - import re - if re.match(reg_str, str(launcher_id)): - result = False - break - else: - import re - if re.match(person_rule, str(launcher_id)): - result = True - return result diff --git a/pkg/qqbot~/blob.py b/pkg/qqbot~/blob.py deleted file mode 100644 index c6edff2..0000000 --- a/pkg/qqbot~/blob.py +++ /dev/null @@ -1,105 +0,0 @@ -# 长消息处理相关 -import logging -import os -import time -import base64 - -import config -from mirai.models.message import MessageComponent, MessageChain, Image -from mirai.models.message import ForwardMessageNode -from mirai.models.base import MiraiBaseModel -from typing import List -import pkg.utils.context as context -import pkg.utils.text2img as text2img - - -class ForwardMessageDiaplay(MiraiBaseModel): - title: str = "群聊的聊天记录" - brief: str = "[聊天记录]" - source: str = "聊天记录" - preview: List[str] = [] - summary: str = "查看x条转发消息" - - -class Forward(MessageComponent): - """合并转发。""" - type: str = "Forward" - """消息组件类型。""" - display: ForwardMessageDiaplay - """显示信息""" - node_list: List[ForwardMessageNode] - """转发消息节点列表。""" - def __init__(self, *args, **kwargs): - if len(args) == 1: - self.node_list = args[0] - super().__init__(**kwargs) - super().__init__(*args, **kwargs) - - def __str__(self): - return '[聊天记录]' - - -def text_to_image(text: str) -> MessageComponent: - """将文本转换成图片""" - # 检查temp文件夹是否存在 - if not os.path.exists('temp'): - os.mkdir('temp') - img_path = text2img.text_to_image(text_str=text, save_as='temp/{}.png'.format(int(time.time()))) - - compressed_path, size = text2img.compress_image(img_path, outfile="temp/{}_compressed.png".format(int(time.time()))) - # 读取图片,转换成base64 - with open(compressed_path, 'rb') as f: - img = f.read() - - b64 = base64.b64encode(img) - - # 删除图片 - os.remove(img_path) - - # 判断compressed_path是否存在 - if os.path.exists(compressed_path): - os.remove(compressed_path) - # 返回图片 - return Image(base64=b64.decode('utf-8')) - - -def check_text(text: str) -> list: - """检查文本是否为长消息,并转换成该使用的消息链组件""" - if not hasattr(config, 'blob_message_threshold'): - return [text] - - if len(text) > config.blob_message_threshold: - if not hasattr(config, 'blob_message_strategy'): - raise AttributeError('未定义长消息处理策略') - - # logging.info("长消息: {}".format(text)) - if config.blob_message_strategy == 'image': - # 转换成图片 - return [text_to_image(text)] - elif config.blob_message_strategy == 'forward': - # 敏感词屏蔽 - text = context.get_qqbot_manager().reply_filter.process(text) - - # 包装转发消息 - display = ForwardMessageDiaplay( - title='群聊的聊天记录', - brief='[聊天记录]', - source='聊天记录', - preview=["bot: "+text], - summary="查看1条转发消息" - ) - - node = ForwardMessageNode( - sender_id=config.mirai_http_api_config['qq'], - sender_name='bot', - message_chain=MessageChain([text]) - ) - - forward = Forward( - display=display, - node_list=[node] - ) - - return [forward] - else: - return [text] \ No newline at end of file diff --git a/pkg/qqbot~/command.py b/pkg/qqbot~/command.py deleted file mode 100644 index b174d45..0000000 --- a/pkg/qqbot~/command.py +++ /dev/null @@ -1,359 +0,0 @@ -# 指令处理模块 -import logging -import json -import datetime -import os -import threading - -import pkg.openai.session -import pkg.openai.manager -import pkg.utils.reloader -import pkg.utils.updater -import pkg.utils.context -import pkg.qqbot.message -import pkg.utils.credit as credit - -from mirai import Image - - -def config_operation(cmd, params): - reply = [] - config = pkg.utils.context.get_config() - reply_str = "" - if len(params) == 0: - reply = ["[bot]err:请输入配置项"] - else: - cfg_name = params[0] - if cfg_name == 'all': - reply_str = "[bot]所有配置项:\n\n" - for cfg in dir(config): - if not cfg.startswith('__') and not cfg == 'logging': - # 根据配置项类型进行格式化,如果是字典则转换为json并格式化 - if isinstance(getattr(config, cfg), str): - reply_str += "{}: \"{}\"\n".format(cfg, getattr(config, cfg)) - elif isinstance(getattr(config, cfg), dict): - # 不进行unicode转义,并格式化 - reply_str += "{}: {}\n".format(cfg, - json.dumps(getattr(config, cfg), - ensure_ascii=False, indent=4)) - else: - reply_str += "{}: {}\n".format(cfg, getattr(config, cfg)) - reply = [reply_str] - elif cfg_name in dir(config): - if len(params) == 1: - # 按照配置项类型进行格式化 - if isinstance(getattr(config, cfg_name), str): - reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, getattr(config, cfg_name)) - elif isinstance(getattr(config, cfg_name), dict): - reply_str = "[bot]配置项{}: {}\n".format(cfg_name, - json.dumps(getattr(config, cfg_name), - ensure_ascii=False, indent=4)) - else: - reply_str = "[bot]配置项{}: {}\n".format(cfg_name, getattr(config, cfg_name)) - reply = [reply_str] - else: - cfg_value = " ".join(params[1:]) - # 类型转换,如果是json则转换为字典 - if cfg_value == 'true': - cfg_value = True - elif cfg_value == 'false': - cfg_value = False - elif cfg_value.isdigit(): - cfg_value = int(cfg_value) - elif cfg_value.startswith('{') and cfg_value.endswith('}'): - cfg_value = json.loads(cfg_value) - else: - try: - cfg_value = float(cfg_value) - except ValueError: - pass - - # 检查类型是否匹配 - if isinstance(getattr(config, cfg_name), type(cfg_value)): - setattr(config, cfg_name, cfg_value) - pkg.utils.context.set_config(config) - reply = ["[bot]配置项{}修改成功".format(cfg_name)] - else: - reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)] - - else: - reply = ["[bot]err:未找到配置项 {}".format(cfg_name)] - - return reply - - -def plugin_operation(cmd, params, is_admin): - reply = [] - - import pkg.plugin.host as plugin_host - import pkg.utils.updater as updater - - plugin_list = plugin_host.__plugins__ - - if len(params) == 0: - reply_str = "[bot]所有插件({}):\n".format(len(plugin_host.__plugins__)) - idx = 0 - for key in plugin_host.iter_plugins_name(): - plugin = plugin_list[key] - reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\ - .format((idx+1), plugin['name'], - "[已禁用]" if not plugin['enabled'] else "", - plugin['description'], - plugin['version'], plugin['author']) - - if updater.is_repo("/".join(plugin['path'].split('/')[:-1])): - remote_url = updater.get_remote_url("/".join(plugin['path'].split('/')[:-1])) - if remote_url != "https://github.com/RockChinQ/QChatGPT" and remote_url != "https://gitee.com/RockChin/QChatGPT": - reply_str += "源码: "+remote_url+"\n" - - idx += 1 - - reply = [reply_str] - elif params[0] == 'update': - # 更新所有插件 - if is_admin: - def closure(): - import pkg.utils.context - updated = [] - for key in plugin_list: - plugin = plugin_list[key] - if updater.is_repo("/".join(plugin['path'].split('/')[:-1])): - success = updater.pull_latest("/".join(plugin['path'].split('/')[:-1])) - if success: - updated.append(plugin['name']) - - # 检查是否有requirements.txt - pkg.utils.context.get_qqbot_manager().notify_admin("正在安装依赖...") - for key in plugin_list: - plugin = plugin_list[key] - if os.path.exists("/".join(plugin['path'].split('/')[:-1])+"/requirements.txt"): - logging.info("{}检测到requirements.txt,安装依赖".format(plugin['name'])) - import pkg.utils.pkgmgr - pkg.utils.pkgmgr.install_requirements("/".join(plugin['path'].split('/')[:-1])+"/requirements.txt") - - import main - main.reset_logging() - - pkg.utils.context.get_qqbot_manager().notify_admin("已更新插件: {}".format(", ".join(updated))) - - threading.Thread(target=closure).start() - reply = ["[bot]正在更新所有插件,请勿重复发起..."] - else: - reply = ["[bot]err:权限不足"] - elif params[0].startswith("http"): - if is_admin: - - def closure(): - try: - plugin_host.install_plugin(params[0]) - pkg.utils.context.get_qqbot_manager().notify_admin("插件安装成功,请发送 !reload 指令重载插件") - except Exception as e: - logging.error("插件安装失败:{}".format(e)) - pkg.utils.context.get_qqbot_manager().notify_admin("插件安装失败:{}".format(e)) - - threading.Thread(target=closure, args=()).start() - reply = ["[bot]正在安装插件..."] - else: - reply = ["[bot]err:权限不足,请使用管理员账号私聊发起"] - return reply - - -def process_command(session_name: str, text_message: str, mgr, config, - launcher_type: str, launcher_id: int, sender_id: int, is_admin: bool) -> list: - reply = [] - try: - logging.info( - "[{}]发起指令:{}".format(session_name, text_message[:min(20, len(text_message))] + ( - "..." if len(text_message) > 20 else ""))) - - cmd = text_message[1:].strip().split(' ')[0] - - params = text_message[1:].strip().split(' ')[1:] - if cmd == 'help': - reply = ["[bot]" + config.help_message] - elif cmd == 'reset': - if len(params) == 0: - pkg.openai.session.get_session(session_name).reset(explicit=True) - reply = ["[bot]会话已重置"] - else: - pkg.openai.session.get_session(session_name).reset(explicit=True, use_prompt=params[0]) - reply = ["[bot]会话已重置,使用场景预设:{}".format(params[0])] - elif cmd == 'last': - result = pkg.openai.session.get_session(session_name).last_session() - if result is None: - reply = ["[bot]没有前一次的对话"] - else: - datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime( - '%Y-%m-%d %H:%M:%S') - reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format(datetime_str)] - elif cmd == 'next': - result = pkg.openai.session.get_session(session_name).next_session() - if result is None: - reply = ["[bot]没有后一次的对话"] - else: - datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime( - '%Y-%m-%d %H:%M:%S') - reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format(datetime_str)] - elif cmd == 'prompt': - msgs = "" - session:list = pkg.openai.session.get_session(session_name).prompt - for msg in session: - if len(params) != 0 and params[0] in ['-all', '-a']: - msgs = msgs + "{}: {}\n\n".format(msg['role'], msg['content']) - elif len(msg['content']) > 30: - msgs = msgs + "[{}]: {}...\n\n".format(msg['role'], msg['content'][:30]) - else: - msgs = msgs + "[{}]: {}\n\n".format(msg['role'], msg['content']) - reply = ["[bot]当前对话所有内容:\n{}".format(msgs)] - elif cmd == 'list': - pkg.openai.session.get_session(session_name).persistence() - page = 0 - - if len(params) > 0: - try: - page = int(params[0]) - except ValueError: - pass - - results = pkg.openai.session.get_session(session_name).list_history(page=page) - if len(results) == 0: - reply = ["[bot]第{}页没有历史会话".format(page)] - else: - reply_str = "[bot]历史会话 第{}页:\n".format(page) - current = -1 - for i in range(len(results)): - # 时间(使用create_timestamp转换) 序号 部分内容 - datetime_obj = datetime.datetime.fromtimestamp(results[i]['create_timestamp']) - msg = "" - try: - msg = json.loads(results[i]['prompt']) - except json.decoder.JSONDecodeError: - msg = pkg.openai.session.reset_session_prompt(session_name, results[i]['prompt']) - # 持久化 - pkg.openai.session.get_session(session_name).persistence() - if len(msg) >= 2: - reply_str += "#{} 创建:{} {}\n".format(i + page * 10, - datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), - msg[1]['content']) - else: - reply_str += "#{} 创建:{} {}\n".format(i + page * 10, - datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), - "无内容") - if results[i]['create_timestamp'] == pkg.openai.session.get_session( - session_name).create_timestamp: - current = i + page * 10 - - reply_str += "\n以上信息倒序排列" - if current != -1: - reply_str += ",当前会话是 #{}\n".format(current) - else: - reply_str += ",当前处于全新会话或不在此页" - - reply = [reply_str] - elif cmd == 'resend': - session = pkg.openai.session.get_session(session_name) - to_send = session.undo() - - reply = pkg.qqbot.message.process_normal_message(to_send, mgr, config, - launcher_type, launcher_id, sender_id) - elif cmd == 'usage': - reply_str = "[bot]各api-key使用情况:\n\n" - - api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key - for key_name in api_keys: - text_length = pkg.utils.context.get_openai_manager().audit_mgr \ - .get_text_length_of_key(api_keys[key_name]) - image_count = pkg.utils.context.get_openai_manager().audit_mgr \ - .get_image_count_of_key(api_keys[key_name]) - reply_str += "{}:\n - 文本长度:{}\n - 图片数量:{}\n".format(key_name, int(text_length), - int(image_count)) - # 获取此key的额度 - try: - credit_data = credit.fetch_credit_data(api_keys[key_name]) - reply_str += " - 使用额度:{:.2f}/{:.2f}\n".format(credit_data['total_used'],credit_data['total_granted']) - except Exception as e: - logging.warning("获取额度失败:{}".format(e)) - - reply = [reply_str] - elif cmd == 'draw': - if len(params) == 0: - reply = ["[bot]err:请输入图片描述文字"] - else: - session = pkg.openai.session.get_session(session_name) - - res = session.draw_image(" ".join(params)) - - logging.debug("draw_image result:{}".format(res)) - reply = [Image(url=res['data'][0]['url'])] - if not (hasattr(config, 'include_image_description') - and not config.include_image_description): - reply.append(" ".join(params)) - elif cmd == 'version': - reply_str = "[bot]当前版本:\n{}\n".format(pkg.utils.updater.get_current_version_info()) - try: - if pkg.utils.updater.is_new_version_available(): - reply_str += "\n有新版本可用,请使用命令 !update 进行更新" - except: - pass - - reply = [reply_str] - - elif cmd == 'plugin': - reply = plugin_operation(cmd, params, is_admin) - - elif cmd == 'default': - if len(params) == 0: - # 输出目前所有情景预设 - import pkg.openai.dprompt as dprompt - reply_str = "[bot]当前所有情景预设:\n\n" - for key,value in dprompt.get_prompt_dict().items(): - reply_str += " - {}: {}\n".format(key,value) - - reply_str += "\n当前默认情景预设:{}\n".format(dprompt.get_current()) - reply_str += "请使用!default <情景预设>来设置默认情景预设" - reply = [reply_str] - elif len(params) >0 and is_admin: - # 设置默认情景 - import pkg.openai.dprompt as dprompt - try: - dprompt.set_current(params[0]) - reply = ["[bot]已设置默认情景预设为:{}".format(dprompt.get_current())] - except KeyError: - reply = ["[bot]err: 未找到情景预设:{}".format(params[0])] - else: - reply = ["[bot]err: 仅管理员可设置默认情景预设"] - elif cmd == 'reload' and is_admin: - def reload_task(): - pkg.utils.reloader.reload_all() - - threading.Thread(target=reload_task, daemon=True).start() - elif cmd == 'update' and is_admin: - def update_task(): - try: - if pkg.utils.updater.update_all(): - pkg.utils.reloader.reload_all(notify=False) - pkg.utils.context.get_qqbot_manager().notify_admin("更新完成") - else: - pkg.utils.context.get_qqbot_manager().notify_admin("无新版本") - except Exception as e0: - pkg.utils.context.get_qqbot_manager().notify_admin("更新失败:{}".format(e0)) - return - - threading.Thread(target=update_task, daemon=True).start() - - reply = ["[bot]正在更新,请耐心等待,请勿重复发起更新..."] - elif cmd == 'cfg' and is_admin: - reply = config_operation(cmd, params) - else: - if cmd.startswith("~") and is_admin: - config_item = cmd[1:] - params = [config_item] + params - reply = config_operation("cfg", params) - else: - reply = ["[bot]err:未知的指令或权限不足: " + cmd] - except Exception as e: - mgr.notify_admin("{}指令执行失败:{}".format(session_name, e)) - logging.exception(e) - reply = ["[bot]err:{}".format(e)] - - return reply diff --git a/pkg/qqbot~/filter.py b/pkg/qqbot~/filter.py deleted file mode 100644 index f0efeda..0000000 --- a/pkg/qqbot~/filter.py +++ /dev/null @@ -1,84 +0,0 @@ -# 敏感词过滤模块 -import re -import requests -import json -import logging - - -class ReplyFilter: - sensitive_words = [] - mask = "*" - mask_word = "" - - # 默认值( 兼容性考虑 ) - baidu_check = False - baidu_api_key = "" - baidu_secret_key = "" - inappropriate_message_tips = "[百度云]请珍惜机器人,当前返回内容不合规" - - def __init__(self, sensitive_words: list, mask: str = "*", mask_word: str = ""): - self.sensitive_words = sensitive_words - self.mask = mask - self.mask_word = mask_word - import config - if hasattr(config, 'baidu_check') and hasattr(config, 'baidu_api_key') and hasattr(config, 'baidu_secret_key'): - self.baidu_check = config.baidu_check - self.baidu_api_key = config.baidu_api_key - self.baidu_secret_key = config.baidu_secret_key - self.inappropriate_message_tips = config.inappropriate_message_tips - - def is_illegal(self, message: str) -> bool: - processed = self.process(message) - if processed != message: - return True - return False - - def process(self, message: str) -> str: - - # 本地关键词屏蔽 - for word in self.sensitive_words: - match = re.findall(word, message) - if len(match) > 0: - for i in range(len(match)): - if self.mask_word == "": - message = message.replace(match[i], self.mask * len(match[i])) - else: - message = message.replace(match[i], self.mask_word) - - # 百度云审核 - if self.baidu_check: - - # 百度云审核URL - baidu_url = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token=" + \ - str(requests.post("https://aip.baidubce.com/oauth/2.0/token", - params={"grant_type": "client_credentials", - "client_id": self.baidu_api_key, - "client_secret": self.baidu_secret_key}).json().get("access_token")) - - # 百度云审核 - payload = "text=" + message - logging.info("向百度云发送:" + payload) - headers = {'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'} - - if isinstance(payload, str): - payload = payload.encode('utf-8') - - response = requests.request("POST", baidu_url, headers=headers, data=payload) - response_dict = json.loads(response.text) - - if "error_code" in response_dict: - error_msg = response_dict.get("error_msg") - logging.warning(f"百度云判定出错,错误信息:{error_msg}") - conclusion = f"百度云判定出错,错误信息:{error_msg}\n以下是原消息:{message}" - else: - conclusion = response_dict["conclusion"] - if conclusion in ("合规"): - logging.info(f"百度云判定结果:{conclusion}") - return message - else: - logging.warning(f"百度云判定结果:{conclusion}") - conclusion = self.inappropriate_message_tips - # 返回百度云审核结果 - return conclusion - - return message diff --git a/pkg/qqbot~/ignore.py b/pkg/qqbot~/ignore.py deleted file mode 100644 index 01994b2..0000000 --- a/pkg/qqbot~/ignore.py +++ /dev/null @@ -1,19 +0,0 @@ -import re - - -def ignore(msg: str) -> bool: - """检查消息是否应该被忽略""" - import config - - if not hasattr(config, 'ignore_rules'): - return False - - if 'prefix' in config.ignore_rules: - for rule in config.ignore_rules['prefix']: - if msg.startswith(rule): - return True - - if 'regexp' in config.ignore_rules: - for rule in config.ignore_rules['regexp']: - if re.search(rule, msg): - return True diff --git a/pkg/qqbot~/manager.py b/pkg/qqbot~/manager.py deleted file mode 100644 index 5d817ee..0000000 --- a/pkg/qqbot~/manager.py +++ /dev/null @@ -1,357 +0,0 @@ -import asyncio -import json -import os -import threading -from concurrent.futures import ThreadPoolExecutor - -import mirai.models.bus -from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \ - FriendMessage, Image -from func_timeout import func_set_timeout - -import pkg.openai.session -import pkg.openai.manager -from func_timeout import FunctionTimedOut -import logging - -import pkg.qqbot.filter -import pkg.qqbot.process as processor -import pkg.utils.context - -import pkg.plugin.host as plugin_host -import pkg.plugin.models as plugin_models - - -# 检查消息是否符合泛响应匹配机制 -def check_response_rule(text: str): - config = pkg.utils.context.get_config() - if not hasattr(config, 'response_rules'): - return False, '' - - rules = config.response_rules - # 检查前缀匹配 - if 'prefix' in rules: - for rule in rules['prefix']: - if text.startswith(rule): - return True, text.replace(rule, "", 1) - - # 检查正则表达式匹配 - if 'regexp' in rules: - for rule in rules['regexp']: - import re - match = re.match(rule, text) - if match: - return True, text - - return False, "" - - -def response_at(): - config = pkg.utils.context.get_config() - if 'at' not in config.response_rules: - return True - - return config.response_rules['at'] - - -def random_responding(): - config = pkg.utils.context.get_config() - if 'random_rate' in config.response_rules: - import random - return random.random() < config.response_rules['random_rate'] - return False - - -# 控制QQ消息输入输出的类 -class QQBotManager: - retry = 3 - - #线程池控制 - pool = None - - bot: Mirai = None - - reply_filter = None - - enable_banlist = False - - ban_person = [] - ban_group = [] - - def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, pool_num: int = 10, first_time_init=True): - self.timeout = timeout - self.retry = retry - - self.pool_num = pool_num - self.pool = ThreadPoolExecutor(max_workers=self.pool_num) - logging.debug("Registered thread pool Size:{}".format(pool_num)) - - # 加载禁用列表 - if os.path.exists("banlist.py"): - import banlist - self.enable_banlist = banlist.enable - self.ban_person = banlist.person - self.ban_group = banlist.group - logging.info("加载禁用列表: person: {}, group: {}".format(self.ban_person, self.ban_group)) - - config = pkg.utils.context.get_config() - if os.path.exists("sensitive.json") \ - and config.sensitive_word_filter is not None \ - and config.sensitive_word_filter: - with open("sensitive.json", "r", encoding="utf-8") as f: - sensitive_json = json.load(f) - self.reply_filter = pkg.qqbot.filter.ReplyFilter( - sensitive_words=sensitive_json['words'], - mask=sensitive_json['mask'] if 'mask' in sensitive_json else '*', - mask_word=sensitive_json['mask_word'] if 'mask_word' in sensitive_json else '' - ) - else: - self.reply_filter = pkg.qqbot.filter.ReplyFilter([]) - - # 由于YiriMirai的bot对象是单例的,且shutdown方法暂时无法使用 - # 故只在第一次初始化时创建bot对象,重载之后使用原bot对象 - # 因此,bot的配置不支持热重载 - if first_time_init: - self.first_time_init(mirai_http_api_config) - else: - self.bot = pkg.utils.context.get_qqbot_manager().bot - - pkg.utils.context.set_qqbot_manager(self) - - # Caution: 注册新的事件处理器之后,请务必在unsubscribe_all中编写相应的取消订阅代码 - @self.bot.on(FriendMessage) - async def on_friend_message(event: FriendMessage): - - def friend_message_handler(event: FriendMessage): - - # 触发事件 - args = { - "launcher_type": "person", - "launcher_id": event.sender.id, - "sender_id": event.sender.id, - "message_chain": event.message_chain, - } - plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args) - - if plugin_event.is_prevented_default(): - return - - self.on_person_message(event) - - self.go(friend_message_handler, event) - - @self.bot.on(StrangerMessage) - async def on_stranger_message(event: StrangerMessage): - - def stranger_message_handler(event: StrangerMessage): - # 触发事件 - args = { - "launcher_type": "person", - "launcher_id": event.sender.id, - "sender_id": event.sender.id, - "message_chain": event.message_chain, - } - plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args) - - if plugin_event.is_prevented_default(): - return - - self.on_person_message(event) - - self.go(stranger_message_handler, event) - - @self.bot.on(GroupMessage) - async def on_group_message(event: GroupMessage): - - def group_message_handler(event: GroupMessage): - # 触发事件 - args = { - "launcher_type": "group", - "launcher_id": event.group.id, - "sender_id": event.sender.id, - "message_chain": event.message_chain, - } - plugin_event = plugin_host.emit(plugin_models.GroupMessageReceived, **args) - - if plugin_event.is_prevented_default(): - return - - self.on_group_message(event) - - self.go(group_message_handler, event) - - def unsubscribe_all(): - """取消所有订阅 - - 用于在热重载流程中卸载所有事件处理器 - """ - assert isinstance(self.bot, Mirai) - bus = self.bot.bus - assert isinstance(bus, mirai.models.bus.ModelEventBus) - - bus.unsubscribe(FriendMessage, on_friend_message) - bus.unsubscribe(StrangerMessage, on_stranger_message) - bus.unsubscribe(GroupMessage, on_group_message) - - self.unsubscribe_all = unsubscribe_all - - def go(self, func, *args, **kwargs): - self.pool.submit(func, *args, **kwargs) - - def first_time_init(self, mirai_http_api_config: dict): - """热重载后不再运行此函数""" - - if 'adapter' not in mirai_http_api_config or mirai_http_api_config['adapter'] == "WebSocketAdapter": - bot = Mirai( - qq=mirai_http_api_config['qq'], - adapter=WebSocketAdapter( - verify_key=mirai_http_api_config['verifyKey'], - host=mirai_http_api_config['host'], - port=mirai_http_api_config['port'] - ) - ) - elif mirai_http_api_config['adapter'] == "HTTPAdapter": - bot = Mirai( - qq=mirai_http_api_config['qq'], - adapter=HTTPAdapter( - verify_key=mirai_http_api_config['verifyKey'], - host=mirai_http_api_config['host'], - port=mirai_http_api_config['port'] - ) - ) - - else: - raise Exception("未知的适配器类型") - - self.bot = bot - - def send(self, event, msg, check_quote=True): - config = pkg.utils.context.get_config() - asyncio.run( - self.bot.send(event, msg, quote=True if hasattr(config, - "quote_origin") and config.quote_origin and check_quote else False)) - - # 私聊消息处理 - def on_person_message(self, event: MessageEvent): - import config - reply = '' - - if event.sender.id == self.bot.qq: - pass - else: - if Image in event.message_chain: - pass - else: - # 超时则重试,重试超过次数则放弃 - failed = 0 - for i in range(self.retry): - try: - - @func_set_timeout(config.process_message_timeout) - def time_ctrl_wrapper(): - reply = processor.process_message('person', event.sender.id, str(event.message_chain), - event.message_chain, - event.sender.id) - return reply - - reply = time_ctrl_wrapper() - break - except FunctionTimedOut: - logging.warning("person_{}: 超时,重试中({})".format(event.sender.id, i)) - pkg.openai.session.get_session('person_{}'.format(event.sender.id)).release_response_lock() - if "person_{}".format(event.sender.id) in pkg.qqbot.process.processing: - pkg.qqbot.process.processing.remove('person_{}'.format(event.sender.id)) - failed += 1 - continue - - if failed == self.retry: - pkg.openai.session.get_session('person_{}'.format(event.sender.id)).release_response_lock() - self.notify_admin("{} 请求超时".format("person_{}".format(event.sender.id))) - reply = ["[bot]err:请求超时"] - - if reply: - return self.send(event, reply, check_quote=False) - - # 群消息处理 - def on_group_message(self, event: GroupMessage): - import config - reply = '' - - def process(text=None) -> str: - replys = "" - if At(self.bot.qq) in event.message_chain: - event.message_chain.remove(At(self.bot.qq)) - - # 超时则重试,重试超过次数则放弃 - failed = 0 - for i in range(self.retry): - try: - @func_set_timeout(config.process_message_timeout) - def time_ctrl_wrapper(): - replys = processor.process_message('group', event.group.id, - str(event.message_chain).strip() if text is None else text, - event.message_chain, - event.sender.id) - return replys - - replys = time_ctrl_wrapper() - break - except FunctionTimedOut: - logging.warning("group_{}: 超时,重试中({})".format(event.group.id, i)) - pkg.openai.session.get_session('group_{}'.format(event.group.id)).release_response_lock() - if "group_{}".format(event.group.id) in pkg.qqbot.process.processing: - pkg.qqbot.process.processing.remove('group_{}'.format(event.group.id)) - failed += 1 - continue - - if failed == self.retry: - pkg.openai.session.get_session('group_{}'.format(event.group.id)).release_response_lock() - self.notify_admin("{} 请求超时".format("group_{}".format(event.group.id))) - replys = ["[bot]err:请求超时"] - - return replys - - if Image in event.message_chain: - pass - else: - if At(self.bot.qq) in event.message_chain and response_at(): - # 直接调用 - reply = process() - else: - check, result = check_response_rule(str(event.message_chain).strip()) - - if check: - reply = process(result.strip()) - # 检查是否随机响应 - elif random_responding(): - logging.info("随机响应group_{}消息".format(event.group.id)) - reply = process() - - if reply: - return self.send(event, reply) - - # 通知系统管理员 - def notify_admin(self, message: str): - config = pkg.utils.context.get_config() - if hasattr(config, "admin_qq") and config.admin_qq != 0 and config.admin_qq != []: - logging.info("通知管理员:{}".format(message)) - if type(config.admin_qq) == int: - send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message)) - threading.Thread(target=asyncio.run, args=(send_task,)).start() - else: - for adm in config.admin_qq: - send_task = self.bot.send_friend_message(adm, "[bot]{}".format(message)) - threading.Thread(target=asyncio.run, args=(send_task,)).start() - - - def notify_admin_message_chain(self, message): - config = pkg.utils.context.get_config() - if hasattr(config, "admin_qq") and config.admin_qq != 0 and config.admin_qq != []: - logging.info("通知管理员:{}".format(message)) - if type(config.admin_qq) == int: - send_task = self.bot.send_friend_message(config.admin_qq, message) - threading.Thread(target=asyncio.run, args=(send_task,)).start() - else: - for adm in config.admin_qq: - send_task = self.bot.send_friend_message(adm, message) - threading.Thread(target=asyncio.run, args=(send_task,)).start() diff --git a/pkg/qqbot~/message.py b/pkg/qqbot~/message.py deleted file mode 100644 index e6106df..0000000 --- a/pkg/qqbot~/message.py +++ /dev/null @@ -1,130 +0,0 @@ -# 普通消息处理模块 -import logging -import time -import openai -import pkg.utils.context -import pkg.openai.session - -import pkg.plugin.host as plugin_host -import pkg.plugin.models as plugin_models -import pkg.qqbot.blob as blob - - -def handle_exception(notify_admin: str = "", set_reply: str = "") -> list: - """处理异常,当notify_admin不为空时,会通知管理员,返回通知用户的消息""" - import config - pkg.utils.context.get_qqbot_manager().notify_admin(notify_admin) - if hasattr(config, 'hide_exce_info_to_user') and config.hide_exce_info_to_user: - if hasattr(config, 'alter_tip_message'): - return [config.alter_tip_message] if config.alter_tip_message else [] - else: - return ["[bot]出错了,请重试或联系管理员"] - else: - return [set_reply] - - -def process_normal_message(text_message: str, mgr, config, launcher_type: str, - launcher_id: int, sender_id: int) -> list: - session_name = f"{launcher_type}_{launcher_id}" - logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + ( - "..." if len(text_message) > 20 else ""))) - - session = pkg.openai.session.get_session(session_name) - - unexpected_exception_times = 0 - - max_unexpected_exception_times = 3 - - reply = [] - while True: - if unexpected_exception_times >= max_unexpected_exception_times: - reply = handle_exception(notify_admin=f"{session_name},多次尝试失败。", set_reply=f"[bot]多次尝试失败,请重试或联系管理员") - break - try: - prefix = "[GPT]" if hasattr(config, "show_prefix") and config.show_prefix else "" - - text = session.append(text_message) - - # 触发插件事件 - args = { - "launcher_type": launcher_type, - "launcher_id": launcher_id, - "sender_id": sender_id, - "session": session, - "prefix": prefix, - "response_text": text - } - - event = pkg.plugin.host.emit(plugin_models.NormalMessageResponded, **args) - - if event.get_return_value("prefix") is not None: - prefix = event.get_return_value("prefix") - - if event.get_return_value("reply") is not None: - reply = event.get_return_value("reply") - - if not event.is_prevented_default(): - reply = blob.check_text(prefix + text) - except openai.error.APIConnectionError as e: - err_msg = str(e) - if err_msg.__contains__('Error communicating with OpenAI'): - reply = handle_exception("{}会话调用API失败:{}\n请尝试关闭网络代理来解决此问题。".format(session_name, e), - "[bot]err:调用API失败,请重试或联系管理员,或等待修复") - else: - reply = handle_exception("{}会话调用API失败:{}".format(session_name, e), "[bot]err:调用API失败,请重试或联系管理员,或等待修复") - except openai.error.RateLimitError as e: - logging.debug(type(e)) - logging.debug(e.error['message']) - - if 'message' in e.error and e.error['message'].__contains__('You exceeded your current quota'): - # 尝试切换api-key - current_key_name = pkg.utils.context.get_openai_manager().key_mgr.get_key_name( - pkg.utils.context.get_openai_manager().key_mgr.using_key - ) - pkg.utils.context.get_openai_manager().key_mgr.set_current_exceeded() - - # 触发插件事件 - args = { - 'key_name': current_key_name, - 'usage': pkg.utils.context.get_openai_manager().audit_mgr - .get_usage(pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()), - 'exceeded_keys': pkg.utils.context.get_openai_manager().key_mgr.exceeded, - } - event = plugin_host.emit(plugin_models.KeyExceeded, **args) - - if not event.is_prevented_default(): - switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch() - - if not switched: - reply = handle_exception( - "api-key调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key;如果你认为这是误判,请尝试重启程序。".format( - current_key_name), "[bot]err:API调用额度超额,请联系管理员,或等待修复") - else: - openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key() - mgr.notify_admin("api-key调用额度超限({}),接口报错,已切换到{}".format(current_key_name, name)) - reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"] - continue - elif 'message' in e.error and e.error['message'].__contains__('You can retry your request'): - # 重试 - unexpected_exception_times += 1 - continue - elif 'message' in e.error and e.error['message']\ - .__contains__('The server had an error while processing your request'): - # 重试 - unexpected_exception_times += 1 - continue - else: - reply = handle_exception("{}会话调用API失败:{}".format(session_name, e), - "[bot]err:RateLimitError,请重试或联系作者,或等待修复") - except openai.error.InvalidRequestError as e: - reply = handle_exception("{}API调用参数错误:{}\n\n这可能是由于config.py中的prompt_submit_length参数或" - "completion_api_params中的max_tokens参数数值过大导致的,请尝试将其降低".format( - session_name, e), "[bot]err:API调用参数错误,请联系管理员,或等待修复") - except openai.error.ServiceUnavailableError as e: - reply = handle_exception("{}API调用服务不可用:{}".format(session_name, e), "[bot]err:API调用服务不可用,请重试或联系管理员,或等待修复") - except Exception as e: - logging.exception(e) - reply = handle_exception("{}会话处理异常:{}".format(session_name, e), "[bot]err:{}".format(e)) - break - - return reply diff --git a/pkg/qqbot~/process.py b/pkg/qqbot~/process.py deleted file mode 100644 index 3ca275a..0000000 --- a/pkg/qqbot~/process.py +++ /dev/null @@ -1,168 +0,0 @@ -# 此模块提供了消息处理的具体逻辑的接口 -import asyncio -import time - -import mirai -import logging - -from mirai import MessageChain, Plain - -# 这里不使用动态引入config -# 因为在这里动态引入会卡死程序 -# 而此模块静态引用config与动态引入的表现一致 -# 已弃用,由于超时时间现已动态使用 -# import config as config_init_import - -import pkg.openai.session -import pkg.openai.manager -import pkg.utils.reloader -import pkg.utils.updater -import pkg.utils.context -import pkg.qqbot.message -import pkg.qqbot.command -import pkg.qqbot.ratelimit as ratelimit - -import pkg.plugin.host as plugin_host -import pkg.plugin.models as plugin_models -import pkg.qqbot.ignore as ignore -import pkg.qqbot.banlist as banlist - -processing = [] - - -def is_admin(qq: int) -> bool: - """兼容list和int类型的管理员判断""" - import config - if type(config.admin_qq) == list: - return qq in config.admin_qq - else: - return qq == config.admin_qq - - -def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: MessageChain, - sender_id: int) -> MessageChain: - global processing - - mgr = pkg.utils.context.get_qqbot_manager() - - reply = [] - session_name = "{}_{}".format(launcher_type, launcher_id) - - # 检查发送方是否被禁用 - if banlist.is_banned(launcher_type, launcher_id, sender_id): - logging.info("根据禁用列表忽略{}_{}的消息".format(launcher_type, launcher_id)) - return [] - - if ignore.ignore(text_message): - logging.info("根据忽略规则忽略消息: {}".format(text_message)) - return [] - - # 检查是否被禁言 - if launcher_type == 'group': - result = mgr.bot.member_info(target=launcher_id, member_id=mgr.bot.qq).get() - result = asyncio.run(result) - if result.mute_time_remaining > 0: - logging.info("机器人被禁言,跳过消息处理(group_{},剩余{}s)".format(launcher_id, - result.mute_time_remaining)) - return reply - - import config - if hasattr(config, 'income_msg_check') and config.income_msg_check: - if mgr.reply_filter.is_illegal(text_message): - return MessageChain(Plain("[bot] 你的提问中有不合适的内容, 请更换措辞~")) - - pkg.openai.session.get_session(session_name).acquire_response_lock() - - text_message = text_message.strip() - - # 处理消息 - try: - if session_name in processing: - pkg.openai.session.get_session(session_name).release_response_lock() - return MessageChain([Plain("[bot]err:正在处理中,请稍后再试")]) - - config = pkg.utils.context.get_config() - - processing.append(session_name) - try: - if text_message.startswith('!') or text_message.startswith("!"): # 指令 - # 触发插件事件 - args = { - 'launcher_type': launcher_type, - 'launcher_id': launcher_id, - 'sender_id': sender_id, - 'command': text_message[1:].strip().split(' ')[0], - 'params': text_message[1:].strip().split(' ')[1:], - 'text_message': text_message, - 'is_admin': is_admin(sender_id), - } - event = plugin_host.emit(plugin_models.PersonCommandSent - if launcher_type == 'person' - else plugin_models.GroupCommandSent, **args) - - if event.get_return_value("alter") is not None: - text_message = event.get_return_value("alter") - - # 取出插件提交的返回值赋值给reply - if event.get_return_value("reply") is not None: - reply = event.get_return_value("reply") - - if not event.is_prevented_default(): - reply = pkg.qqbot.command.process_command(session_name, text_message, - mgr, config, launcher_type, launcher_id, sender_id, is_admin(sender_id)) - - else: # 消息 - # 限速丢弃检查 - # print(ratelimit.__crt_minute_usage__[session_name]) - if hasattr(config, "rate_limitation") and config.rate_limit_strategy == "drop": - if ratelimit.is_reach_limit(session_name): - logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message)) - return MessageChain(["[bot]"+config.rate_limit_drop_tip]) if hasattr(config, "rate_limit_drop_tip") and config.rate_limit_drop_tip != "" else [] - - before = time.time() - # 触发插件事件 - args = { - "launcher_type": launcher_type, - "launcher_id": launcher_id, - "sender_id": sender_id, - "text_message": text_message, - } - event = plugin_host.emit(plugin_models.PersonNormalMessageReceived - if launcher_type == 'person' - else plugin_models.GroupNormalMessageReceived, **args) - - if event.get_return_value("alter") is not None: - text_message = event.get_return_value("alter") - - # 取出插件提交的返回值赋值给reply - if event.get_return_value("reply") is not None: - reply = event.get_return_value("reply") - - if not event.is_prevented_default(): - reply = pkg.qqbot.message.process_normal_message(text_message, - mgr, config, launcher_type, launcher_id, sender_id) - - # 限速等待时间 - if hasattr(config, "rate_limitation") and config.rate_limit_strategy == "wait": - time.sleep(ratelimit.get_rest_wait_time(session_name, time.time() - before)) - - if hasattr(config, "rate_limitation"): - ratelimit.add_usage(session_name) - - if reply is not None and len(reply) > 0 and (type(reply[0]) == str or type(reply[0]) == mirai.Plain): - if type(reply[0]) == mirai.Plain: - reply[0] = reply[0].text - logging.info( - "回复[{}]文字消息:{}".format(session_name, - reply[0][:min(100, len(reply[0]))] + ( - "..." if len(reply[0]) > 100 else ""))) - reply = [mgr.reply_filter.process(reply[0])] - else: - logging.info("回复[{}]消息".format(session_name)) - - finally: - processing.remove(session_name) - finally: - pkg.openai.session.get_session(session_name).release_response_lock() - - return MessageChain(reply) diff --git a/pkg/qqbot~/ratelimit.py b/pkg/qqbot~/ratelimit.py deleted file mode 100644 index 2a759b6..0000000 --- a/pkg/qqbot~/ratelimit.py +++ /dev/null @@ -1,86 +0,0 @@ -# 限速相关模块 -import time -import logging -import threading - -__crt_minute_usage__ = {} -"""当前分钟每个会话的对话次数""" - - -__timer_thr__: threading.Thread = None - - -def add_usage(session_name: str): - """增加会话的对话次数""" - global __crt_minute_usage__ - if session_name in __crt_minute_usage__: - __crt_minute_usage__[session_name] += 1 - else: - __crt_minute_usage__[session_name] = 1 - - -def start_timer(): - """启动定时器""" - global __timer_thr__ - __timer_thr__ = threading.Thread(target=run_timer, daemon=True) - __timer_thr__.start() - - -def run_timer(): - """启动定时器,每分钟清空一次对话次数""" - global __crt_minute_usage__ - global __timer_thr__ - - # 等待直到整分钟 - time.sleep(60 - time.time() % 60) - - while True: - if __timer_thr__ != threading.current_thread(): - break - - logging.debug("清空当前分钟的对话次数") - __crt_minute_usage__ = {} - time.sleep(60) - - -def get_usage(session_name: str) -> int: - """获取会话的对话次数""" - global __crt_minute_usage__ - if session_name in __crt_minute_usage__: - return __crt_minute_usage__[session_name] - else: - return 0 - - -def get_rest_wait_time(session_name: str, spent: float) -> float: - """获取会话此回合的剩余等待时间""" - global __crt_minute_usage__ - - import config - - if not hasattr(config, 'rate_limitation'): - return 0 - - min_seconds_per_round = 60.0 / config.rate_limitation - - if session_name in __crt_minute_usage__: - return max(0, min_seconds_per_round - spent) - else: - return 0 - - -def is_reach_limit(session_name: str) -> bool: - """判断会话是否超过限制""" - global __crt_minute_usage__ - - import config - - if not hasattr(config, 'rate_limitation'): - return False - - if session_name in __crt_minute_usage__: - return __crt_minute_usage__[session_name] >= config.rate_limitation - else: - return False - -start_timer()