mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
refactor: 修改引入风格
This commit is contained in:
parent
e3b280758c
commit
665de5dc43
|
@ -9,8 +9,8 @@ import threading
|
|||
|
||||
import requests
|
||||
|
||||
import pkg.utils.context
|
||||
import pkg.utils.updater
|
||||
from ..utils import context
|
||||
from ..utils import updater
|
||||
|
||||
|
||||
class DataGatherer:
|
||||
|
@ -33,7 +33,7 @@ class DataGatherer:
|
|||
def __init__(self):
|
||||
self.load_from_db()
|
||||
try:
|
||||
self.version_str = pkg.utils.updater.get_current_tag() # 从updater模块获取版本号
|
||||
self.version_str = updater.get_current_tag() # 从updater模块获取版本号
|
||||
except:
|
||||
pass
|
||||
|
||||
|
@ -47,7 +47,7 @@ class DataGatherer:
|
|||
def thread_func():
|
||||
|
||||
try:
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
if not config.report_usage:
|
||||
return
|
||||
res = requests.get("http://reports.rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}&msg_source={}".format(subservice_name, self.version_str, count, config.msg_source_adapter))
|
||||
|
@ -64,7 +64,7 @@ class DataGatherer:
|
|||
def report_text_model_usage(self, model, total_tokens):
|
||||
"""调用方报告文字模型请求文字使用量"""
|
||||
|
||||
key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存
|
||||
key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存
|
||||
|
||||
if key_md5 not in self.usage:
|
||||
self.usage[key_md5] = {}
|
||||
|
@ -84,7 +84,7 @@ class DataGatherer:
|
|||
def report_image_model_usage(self, size):
|
||||
"""调用方报告图片模型请求图片使用量"""
|
||||
|
||||
key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()
|
||||
key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5()
|
||||
|
||||
if key_md5 not in self.usage:
|
||||
self.usage[key_md5] = {}
|
||||
|
@ -131,9 +131,9 @@ class DataGatherer:
|
|||
return total
|
||||
|
||||
def dump_to_db(self):
|
||||
pkg.utils.context.get_database_manager().dump_usage_json(self.usage)
|
||||
context.get_database_manager().dump_usage_json(self.usage)
|
||||
|
||||
def load_from_db(self):
|
||||
json_str = pkg.utils.context.get_database_manager().load_usage_json()
|
||||
json_str = context.get_database_manager().load_usage_json()
|
||||
if json_str is not None:
|
||||
self.usage = json.loads(json_str)
|
||||
|
|
|
@ -5,11 +5,10 @@ import hashlib
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from sqlite3 import Cursor
|
||||
|
||||
import sqlite3
|
||||
|
||||
import pkg.utils.context
|
||||
from ..utils import context
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
|
@ -22,7 +21,7 @@ class DatabaseManager:
|
|||
|
||||
self.reconnect()
|
||||
|
||||
pkg.utils.context.set_database_manager(self)
|
||||
context.set_database_manager(self)
|
||||
|
||||
# 连接到数据库文件
|
||||
def reconnect(self):
|
||||
|
@ -33,7 +32,7 @@ class DatabaseManager:
|
|||
def close(self):
|
||||
self.conn.close()
|
||||
|
||||
def __execute__(self, *args, **kwargs) -> Cursor:
|
||||
def __execute__(self, *args, **kwargs) -> sqlite3.Cursor:
|
||||
# logging.debug('SQL: {}'.format(sql))
|
||||
logging.debug('SQL: {}'.format(args))
|
||||
c = self.cursor.execute(*args, **kwargs)
|
||||
|
@ -145,7 +144,7 @@ class DatabaseManager:
|
|||
# 从数据库加载还没过期的session数据
|
||||
def load_valid_sessions(self) -> dict:
|
||||
# 从数据库中加载所有还没过期的session
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
self.__execute__("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
|
||||
from `sessions` where `last_interact_timestamp` > {}
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import openai
|
||||
from openai.types.chat import chat_completion_message
|
||||
import json
|
||||
import logging
|
||||
|
||||
from .model import RequestBase
|
||||
import openai
|
||||
from openai.types.chat import chat_completion_message
|
||||
|
||||
from ..funcmgr import get_func_schema_list, execute_function, get_func, get_func_schema, ContentFunctionNotFoundError
|
||||
from .model import RequestBase
|
||||
from .. import funcmgr
|
||||
|
||||
|
||||
class ChatCompletionRequest(RequestBase):
|
||||
|
@ -81,7 +81,7 @@ class ChatCompletionRequest(RequestBase):
|
|||
"messages": self.messages,
|
||||
}
|
||||
|
||||
funcs = get_func_schema_list()
|
||||
funcs = funcmgr.get_func_schema_list()
|
||||
|
||||
if len(funcs) > 0:
|
||||
args['functions'] = funcs
|
||||
|
@ -171,7 +171,7 @@ class ChatCompletionRequest(RequestBase):
|
|||
# 若不是json格式的异常处理
|
||||
except json.decoder.JSONDecodeError:
|
||||
# 获取函数的参数列表
|
||||
func_schema = get_func_schema(func_name)
|
||||
func_schema = funcmgr.get_func_schema(func_name)
|
||||
|
||||
arguments = {
|
||||
func_schema['parameters']['required'][0]: cp_pending_func_call.arguments
|
||||
|
@ -182,7 +182,7 @@ class ChatCompletionRequest(RequestBase):
|
|||
# 执行函数调用
|
||||
ret = ""
|
||||
try:
|
||||
ret = execute_function(func_name, arguments)
|
||||
ret = funcmgr.execute_function(func_name, arguments)
|
||||
|
||||
logging.info("函数执行完成。")
|
||||
except Exception as e:
|
||||
|
@ -216,6 +216,5 @@ class ChatCompletionRequest(RequestBase):
|
|||
}
|
||||
}
|
||||
|
||||
except ContentFunctionNotFoundError:
|
||||
except funcmgr.ContentFunctionNotFoundError:
|
||||
raise Exception("没有找到函数: {}".format(func_name))
|
||||
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import openai
|
||||
from openai.types import completion, completion_choice
|
||||
|
||||
from .model import RequestBase
|
||||
from . import model
|
||||
|
||||
|
||||
class CompletionRequest(RequestBase):
|
||||
class CompletionRequest(model.RequestBase):
|
||||
"""调用Completion接口的请求类。
|
||||
|
||||
调用方可以一直next completion直到finish_reason为stop。
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
# 定义不同接口请求的模型
|
||||
import threading
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import openai
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
# 封装了function calling的一些支持函数
|
||||
import logging
|
||||
|
||||
|
||||
from pkg.plugin import host
|
||||
from ..plugin import host
|
||||
|
||||
|
||||
class ContentFunctionNotFoundError(Exception):
|
||||
|
|
|
@ -2,8 +2,8 @@
|
|||
import hashlib
|
||||
import logging
|
||||
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.plugin.models as plugin_models
|
||||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
|
||||
|
||||
class KeysManager:
|
||||
|
|
|
@ -2,12 +2,11 @@ import logging
|
|||
|
||||
import openai
|
||||
|
||||
import pkg.openai.keymgr
|
||||
import pkg.utils.context
|
||||
import pkg.audit.gatherer
|
||||
from pkg.openai.modelmgr import select_request_cls
|
||||
|
||||
from pkg.openai.api.model import RequestBase
|
||||
from ..openai import keymgr
|
||||
from ..utils import context
|
||||
from ..audit import gatherer
|
||||
from ..openai import modelmgr
|
||||
from ..openai.api import model as api_model
|
||||
|
||||
|
||||
class OpenAIInteract:
|
||||
|
@ -16,9 +15,9 @@ class OpenAIInteract:
|
|||
将文字接口和图片接口封装供调用方使用
|
||||
"""
|
||||
|
||||
key_mgr: pkg.openai.keymgr.KeysManager = None
|
||||
key_mgr: keymgr.KeysManager = None
|
||||
|
||||
audit_mgr: pkg.audit.gatherer.DataGatherer = None
|
||||
audit_mgr: gatherer.DataGatherer = None
|
||||
|
||||
default_image_api_params = {
|
||||
"size": "256x256",
|
||||
|
@ -28,8 +27,8 @@ class OpenAIInteract:
|
|||
|
||||
def __init__(self, api_key: str):
|
||||
|
||||
self.key_mgr = pkg.openai.keymgr.KeysManager(api_key)
|
||||
self.audit_mgr = pkg.audit.gatherer.DataGatherer()
|
||||
self.key_mgr = keymgr.KeysManager(api_key)
|
||||
self.audit_mgr = gatherer.DataGatherer()
|
||||
|
||||
# logging.info("文字总使用量:%d", self.audit_mgr.get_total_text_length())
|
||||
|
||||
|
@ -37,22 +36,22 @@ class OpenAIInteract:
|
|||
api_key=self.key_mgr.get_using_key()
|
||||
)
|
||||
|
||||
pkg.utils.context.set_openai_manager(self)
|
||||
context.set_openai_manager(self)
|
||||
|
||||
def request_completion(self, messages: list):
|
||||
"""请求补全接口回复=
|
||||
"""
|
||||
# 选择接口请求类
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
|
||||
request: RequestBase
|
||||
request: api_model.RequestBase
|
||||
|
||||
model: str = config.completion_api_params['model']
|
||||
|
||||
cp_parmas = config.completion_api_params.copy()
|
||||
del cp_parmas['model']
|
||||
|
||||
request = select_request_cls(self.client, model, messages, cp_parmas)
|
||||
request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas)
|
||||
|
||||
# 请求接口
|
||||
for resp in request:
|
||||
|
@ -74,7 +73,7 @@ class OpenAIInteract:
|
|||
Returns:
|
||||
dict: 响应
|
||||
"""
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
params = config.image_api_params
|
||||
|
||||
response = openai.Image.create(
|
||||
|
|
|
@ -8,9 +8,9 @@ Completion - text-davinci-003 等模型
|
|||
import tiktoken
|
||||
import openai
|
||||
|
||||
from pkg.openai.api.model import RequestBase
|
||||
from pkg.openai.api.completion import CompletionRequest
|
||||
from pkg.openai.api.chat_completion import ChatCompletionRequest
|
||||
from ..openai.api import model as api_model
|
||||
from ..openai.api import completion as api_completion
|
||||
from ..openai.api import chat_completion as api_chat_completion
|
||||
|
||||
COMPLETION_MODELS = {
|
||||
"text-davinci-003", # legacy
|
||||
|
@ -60,11 +60,11 @@ IMAGE_MODELS = {
|
|||
}
|
||||
|
||||
|
||||
def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> RequestBase:
|
||||
def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> api_model.RequestBase:
|
||||
if model_name in CHAT_COMPLETION_MODELS:
|
||||
return ChatCompletionRequest(client, model_name, messages, **args)
|
||||
return api_chat_completion.ChatCompletionRequest(client, model_name, messages, **args)
|
||||
elif model_name in COMPLETION_MODELS:
|
||||
return CompletionRequest(client, model_name, messages, **args)
|
||||
return api_completion.CompletionRequest(client, model_name, messages, **args)
|
||||
raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))
|
||||
|
||||
|
||||
|
|
|
@ -8,15 +8,13 @@ import threading
|
|||
import time
|
||||
import json
|
||||
|
||||
import pkg.openai.manager
|
||||
import pkg.openai.modelmgr
|
||||
import pkg.database.manager
|
||||
import pkg.utils.context
|
||||
from ..openai import manager as openai_manager
|
||||
from ..openai import modelmgr as openai_modelmgr
|
||||
from ..database import manager as database_manager
|
||||
from ..utils import context as context
|
||||
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.plugin.models as plugin_models
|
||||
|
||||
from pkg.openai.modelmgr import count_tokens
|
||||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
|
||||
# 运行时保存的所有session
|
||||
sessions = {}
|
||||
|
@ -38,7 +36,7 @@ def reset_session_prompt(session_name, prompt):
|
|||
f.write(prompt)
|
||||
f.close()
|
||||
# 生成新数据
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
prompt = [
|
||||
{
|
||||
'role': 'system',
|
||||
|
@ -61,7 +59,7 @@ def load_sessions():
|
|||
|
||||
global sessions
|
||||
|
||||
db_inst = pkg.utils.context.get_database_manager()
|
||||
db_inst = context.get_database_manager()
|
||||
|
||||
session_data = db_inst.load_valid_sessions()
|
||||
|
||||
|
@ -172,7 +170,7 @@ class Session:
|
|||
if self.create_timestamp != create_timestamp or self not in sessions.values():
|
||||
return
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
if int(time.time()) - self.last_interact_timestamp > config.session_expire_time:
|
||||
logging.info('session {} 已过期'.format(self.name))
|
||||
|
||||
|
@ -182,7 +180,7 @@ class Session:
|
|||
'session': self,
|
||||
'session_expire_time': config.session_expire_time
|
||||
}
|
||||
event = pkg.plugin.host.emit(plugin_models.SessionExpired, **args)
|
||||
event = plugin_host.emit(plugin_models.SessionExpired, **args)
|
||||
if event.is_prevented_default():
|
||||
return
|
||||
|
||||
|
@ -214,11 +212,11 @@ class Session:
|
|||
'default_prompt': self.default_prompt,
|
||||
}
|
||||
|
||||
event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args)
|
||||
event = plugin_host.emit(plugin_models.SessionFirstMessageReceived, **args)
|
||||
if event.is_prevented_default():
|
||||
return None, None, None
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
max_length = config.prompt_submit_length
|
||||
|
||||
local_default_prompt = self.default_prompt.copy()
|
||||
|
@ -232,7 +230,7 @@ class Session:
|
|||
'text_message': text,
|
||||
}
|
||||
|
||||
event = pkg.plugin.host.emit(plugin_models.PromptPreProcessing, **args)
|
||||
event = plugin_host.emit(plugin_models.PromptPreProcessing, **args)
|
||||
|
||||
if event.get_return_value('default_prompt') is not None:
|
||||
local_default_prompt = event.get_return_value('default_prompt')
|
||||
|
@ -256,14 +254,14 @@ class Session:
|
|||
funcs = []
|
||||
|
||||
trace_func_calls = config.trace_function_calls
|
||||
botmgr = pkg.utils.context.get_qqbot_manager()
|
||||
botmgr = context.get_qqbot_manager()
|
||||
|
||||
session_name_spt: list[str] = self.name.split("_")
|
||||
|
||||
pending_res_text = ""
|
||||
|
||||
# TODO 对不起,我知道这样非常非常屎山,但我之后会重构的
|
||||
for resp in pkg.utils.context.get_openai_manager().request_completion(prompts):
|
||||
for resp in context.get_openai_manager().request_completion(prompts):
|
||||
|
||||
if pending_res_text != "":
|
||||
botmgr.adapter.send_message(
|
||||
|
@ -325,7 +323,6 @@ class Session:
|
|||
)
|
||||
pass
|
||||
|
||||
|
||||
# 向API请求补全
|
||||
# message, total_token = pkg.utils.context.get_openai_manager().request_completion(
|
||||
# prompts,
|
||||
|
@ -383,13 +380,13 @@ class Session:
|
|||
# 包装目前的对话回合内容
|
||||
changable_prompts = []
|
||||
|
||||
use_model = pkg.utils.context.get_config().completion_api_params['model']
|
||||
use_model = context.get_config().completion_api_params['model']
|
||||
|
||||
ptr = len(prompt) - 1
|
||||
|
||||
# 直接从后向前扫描拼接,不管是否是整回合
|
||||
while ptr >= 0:
|
||||
if count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
|
||||
if openai_modelmgr.count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
|
||||
break
|
||||
|
||||
changable_prompts.insert(0, prompt[ptr])
|
||||
|
@ -410,14 +407,14 @@ class Session:
|
|||
|
||||
logging.debug("cut_out: {}".format(json.dumps(result_prompt, ensure_ascii=False, indent=4)))
|
||||
|
||||
return result_prompt, count_tokens(changable_prompts, use_model)
|
||||
return result_prompt, openai_modelmgr.count_tokens(changable_prompts, use_model)
|
||||
|
||||
# 持久化session
|
||||
def persistence(self):
|
||||
if self.prompt == self.get_default_prompt():
|
||||
return
|
||||
|
||||
db_inst = pkg.utils.context.get_database_manager()
|
||||
db_inst = context.get_database_manager()
|
||||
|
||||
name_spt = self.name.split('_')
|
||||
|
||||
|
@ -439,12 +436,12 @@ class Session:
|
|||
}
|
||||
|
||||
# 此事件不支持阻止默认行为
|
||||
_ = pkg.plugin.host.emit(plugin_models.SessionExplicitReset, **args)
|
||||
_ = plugin_host.emit(plugin_models.SessionExplicitReset, **args)
|
||||
|
||||
pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
|
||||
context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
|
||||
|
||||
if expired:
|
||||
pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
|
||||
context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
|
||||
|
||||
if not persist: # 不要求保持default prompt
|
||||
self.default_prompt = self.get_default_prompt(use_prompt)
|
||||
|
@ -461,11 +458,11 @@ class Session:
|
|||
|
||||
# 将本session的数据库状态设置为on_going
|
||||
def set_ongoing(self):
|
||||
pkg.utils.context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)
|
||||
context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)
|
||||
|
||||
# 切换到上一个session
|
||||
def last_session(self):
|
||||
last_one = pkg.utils.context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
|
||||
last_one = context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
|
||||
if last_one is None:
|
||||
return None
|
||||
else:
|
||||
|
@ -486,7 +483,7 @@ class Session:
|
|||
|
||||
# 切换到下一个session
|
||||
def next_session(self):
|
||||
next_one = pkg.utils.context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
|
||||
next_one = context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
|
||||
if next_one is None:
|
||||
return None
|
||||
else:
|
||||
|
@ -506,13 +503,13 @@ class Session:
|
|||
return self
|
||||
|
||||
def list_history(self, capacity: int = 10, page: int = 0):
|
||||
return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page)
|
||||
return context.get_database_manager().list_history(self.name, capacity, page)
|
||||
|
||||
def delete_history(self, index: int) -> bool:
|
||||
return pkg.utils.context.get_database_manager().delete_history(self.name, index)
|
||||
return context.get_database_manager().delete_history(self.name, index)
|
||||
|
||||
def delete_all_history(self) -> bool:
|
||||
return pkg.utils.context.get_database_manager().delete_all_history(self.name)
|
||||
return context.get_database_manager().delete_all_history(self.name)
|
||||
|
||||
def draw_image(self, prompt: str):
|
||||
return pkg.utils.context.get_openai_manager().request_image(prompt)
|
||||
return context.get_openai_manager().request_image(prompt)
|
||||
|
|
|
@ -10,13 +10,13 @@ import traceback
|
|||
import time
|
||||
import re
|
||||
|
||||
import pkg.utils.updater as updater
|
||||
import pkg.utils.context as context
|
||||
import pkg.plugin.switch as switch
|
||||
import pkg.plugin.settings as settings
|
||||
import pkg.qqbot.adapter as msadapter
|
||||
import pkg.utils.network as network
|
||||
import pkg.plugin.metadata as metadata
|
||||
from ..utils import updater as updater
|
||||
from ..utils import network as network
|
||||
from ..utils import context as context
|
||||
from ..plugin import switch as switch
|
||||
from ..plugin import settings as settings
|
||||
from ..qqbot import adapter as msadapter
|
||||
from ..plugin import metadata as metadata
|
||||
|
||||
from mirai import Mirai
|
||||
import requests
|
||||
|
@ -147,6 +147,7 @@ def initialize_plugins():
|
|||
successfully_initialized_plugins.append(plugin['name'])
|
||||
except:
|
||||
logging.error("插件{}初始化时发生错误: {}".format(plugin['name'], sys.exc_info()))
|
||||
logging.debug(traceback.format_exc())
|
||||
|
||||
logging.info("以下插件已初始化: {}".format(", ".join(successfully_initialized_plugins)))
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
|
||||
import pkg.plugin.host as host
|
||||
import pkg.utils.context
|
||||
from ..plugin import host
|
||||
from ..utils import context
|
||||
|
||||
PersonMessageReceived = "person_message_received"
|
||||
"""收到私聊消息时,在判断是否应该响应前触发
|
||||
|
@ -285,7 +285,7 @@ def register(name: str, description: str, version: str, author: str):
|
|||
cls.description = description
|
||||
cls.version = version
|
||||
cls.author = author
|
||||
cls.host = pkg.utils.context.get_plugin_host()
|
||||
cls.host = context.get_plugin_host()
|
||||
cls.enabled = True
|
||||
cls.path = host.__current_module_path__
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
import pkg.plugin.host as host
|
||||
import logging
|
||||
|
||||
from ..plugin import host
|
||||
|
||||
def wrapper_dict_from_runtime_context() -> dict:
|
||||
"""从变量中包装settings.json的数据字典"""
|
||||
|
|
|
@ -3,7 +3,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
|
||||
import pkg.plugin.host as host
|
||||
from ..plugin import host
|
||||
|
||||
|
||||
def wrapper_dict_from_plugin_list() -> dict:
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
import pkg.utils.context
|
||||
from ..utils import context
|
||||
|
||||
|
||||
def is_banned(launcher_type: str, launcher_id: int, sender_id: int) -> bool:
|
||||
if not pkg.utils.context.get_qqbot_manager().enable_banlist:
|
||||
if not context.get_qqbot_manager().enable_banlist:
|
||||
return False
|
||||
|
||||
result = False
|
||||
|
||||
if launcher_type == 'group':
|
||||
# 检查是否显式声明发起人QQ要被person忽略
|
||||
if sender_id in pkg.utils.context.get_qqbot_manager().ban_person:
|
||||
if sender_id in context.get_qqbot_manager().ban_person:
|
||||
result = True
|
||||
else:
|
||||
for group_rule in pkg.utils.context.get_qqbot_manager().ban_group:
|
||||
for group_rule in context.get_qqbot_manager().ban_group:
|
||||
if type(group_rule) == int:
|
||||
if group_rule == launcher_id: # 此群群号被禁用
|
||||
result = True
|
||||
|
@ -32,7 +32,7 @@ def is_banned(launcher_type: str, launcher_id: int, sender_id: int) -> bool:
|
|||
|
||||
else:
|
||||
# ban_person, 与群规则相同
|
||||
for person_rule in pkg.utils.context.get_qqbot_manager().ban_person:
|
||||
for person_rule in context.get_qqbot_manager().ban_person:
|
||||
if type(person_rule) == int:
|
||||
if person_rule == launcher_id:
|
||||
result = True
|
||||
|
|
|
@ -2,21 +2,21 @@
|
|||
import os
|
||||
import time
|
||||
import base64
|
||||
import typing
|
||||
|
||||
import config
|
||||
from mirai.models.message import MessageComponent, MessageChain, Image
|
||||
from mirai.models.message import ForwardMessageNode
|
||||
from mirai.models.base import MiraiBaseModel
|
||||
from typing import List
|
||||
import pkg.utils.context as context
|
||||
import pkg.utils.text2img as text2img
|
||||
|
||||
from ..utils import text2img
|
||||
import config
|
||||
|
||||
|
||||
class ForwardMessageDiaplay(MiraiBaseModel):
|
||||
title: str = "群聊的聊天记录"
|
||||
brief: str = "[聊天记录]"
|
||||
source: str = "聊天记录"
|
||||
preview: List[str] = []
|
||||
preview: typing.List[str] = []
|
||||
summary: str = "查看x条转发消息"
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ class Forward(MessageComponent):
|
|||
"""消息组件类型。"""
|
||||
display: ForwardMessageDiaplay
|
||||
"""显示信息"""
|
||||
node_list: List[ForwardMessageNode]
|
||||
node_list: typing.List[ForwardMessageNode]
|
||||
"""转发消息节点列表。"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
if len(args) == 1:
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import copy
|
||||
import pkgutil
|
||||
import traceback
|
||||
import types
|
||||
import json
|
||||
|
||||
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
import logging
|
||||
|
||||
from mirai import Image
|
||||
import mirai
|
||||
|
||||
from .. import aamgr
|
||||
import config
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="draw",
|
||||
description="使用DALL·E生成图片",
|
||||
|
@ -13,9 +14,9 @@ import config
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class DrawCommand(AbstractCommandNode):
|
||||
class DrawCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
|
||||
reply = []
|
||||
|
@ -28,7 +29,7 @@ class DrawCommand(AbstractCommandNode):
|
|||
res = session.draw_image(" ".join(ctx.params))
|
||||
|
||||
logging.debug("draw_image result:{}".format(res))
|
||||
reply = [Image(url=res['data'][0]['url'])]
|
||||
reply = [mirai.Image(url=res['data'][0]['url'])]
|
||||
if not (hasattr(config, 'include_image_description')
|
||||
and not config.include_image_description):
|
||||
reply.append(" ".join(ctx.params))
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
import logging
|
||||
|
||||
import json
|
||||
|
||||
from .. import aamgr
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="func",
|
||||
description="管理内容函数",
|
||||
|
@ -12,9 +11,9 @@ import json
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class FuncCommand(AbstractCommandNode):
|
||||
class FuncCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
from pkg.plugin.models import host
|
||||
|
||||
reply = []
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
|
||||
import os
|
||||
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.utils.updater as updater
|
||||
from ....plugin import host as plugin_host
|
||||
from ....utils import updater
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="plugin",
|
||||
description="插件管理",
|
||||
|
@ -14,9 +11,9 @@ import pkg.utils.updater as updater
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class PluginCommand(AbstractCommandNode):
|
||||
class PluginCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
reply = []
|
||||
plugin_list = plugin_host.__plugins__
|
||||
if len(ctx.params) == 0:
|
||||
|
@ -48,7 +45,7 @@ class PluginCommand(AbstractCommandNode):
|
|||
return False, []
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=PluginCommand,
|
||||
name="get",
|
||||
description="安装插件",
|
||||
|
@ -56,9 +53,9 @@ class PluginCommand(AbstractCommandNode):
|
|||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class PluginGetCommand(AbstractCommandNode):
|
||||
class PluginGetCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import threading
|
||||
import logging
|
||||
import pkg.utils.context
|
||||
|
@ -81,7 +78,7 @@ class PluginGetCommand(AbstractCommandNode):
|
|||
return True, reply
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=PluginCommand,
|
||||
name="update",
|
||||
description="更新指定插件或全部插件",
|
||||
|
@ -89,9 +86,9 @@ class PluginGetCommand(AbstractCommandNode):
|
|||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class PluginUpdateCommand(AbstractCommandNode):
|
||||
class PluginUpdateCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import threading
|
||||
import logging
|
||||
plugin_list = plugin_host.__plugins__
|
||||
|
@ -130,7 +127,7 @@ class PluginUpdateCommand(AbstractCommandNode):
|
|||
return True, reply
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=PluginCommand,
|
||||
name="del",
|
||||
description="删除插件",
|
||||
|
@ -138,9 +135,9 @@ class PluginUpdateCommand(AbstractCommandNode):
|
|||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class PluginDelCommand(AbstractCommandNode):
|
||||
class PluginDelCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
plugin_list = plugin_host.__plugins__
|
||||
reply = []
|
||||
|
||||
|
@ -157,7 +154,7 @@ class PluginDelCommand(AbstractCommandNode):
|
|||
return True, reply
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=PluginCommand,
|
||||
name="on",
|
||||
description="启用指定插件",
|
||||
|
@ -165,7 +162,7 @@ class PluginDelCommand(AbstractCommandNode):
|
|||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=PluginCommand,
|
||||
name="off",
|
||||
description="禁用指定插件",
|
||||
|
@ -173,9 +170,9 @@ class PluginDelCommand(AbstractCommandNode):
|
|||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class PluginOnOffCommand(AbstractCommandNode):
|
||||
class PluginOnOffCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.plugin.switch as plugin_switch
|
||||
|
||||
plugin_list = plugin_host.__plugins__
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="default",
|
||||
description="操作情景预设",
|
||||
|
@ -9,9 +8,9 @@ from ..aamgr import AbstractCommandNode, Context
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class DefaultCommand(AbstractCommandNode):
|
||||
class DefaultCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
session_name = ctx.session_name
|
||||
params = ctx.params
|
||||
|
@ -45,7 +44,7 @@ class DefaultCommand(AbstractCommandNode):
|
|||
return True, reply
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=DefaultCommand,
|
||||
name="set",
|
||||
description="设置默认情景预设",
|
||||
|
@ -53,9 +52,9 @@ class DefaultCommand(AbstractCommandNode):
|
|||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class DefaultSetCommand(AbstractCommandNode):
|
||||
class DefaultSetCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
reply = []
|
||||
|
||||
if len(ctx.crt_params) == 0:
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
import datetime
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="del",
|
||||
description="删除当前会话的历史记录",
|
||||
|
@ -10,9 +9,9 @@ import datetime
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class DelCommand(AbstractCommandNode):
|
||||
class DelCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
session_name = ctx.session_name
|
||||
params = ctx.params
|
||||
|
@ -33,7 +32,7 @@ class DelCommand(AbstractCommandNode):
|
|||
return True, reply
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=DelCommand,
|
||||
name="all",
|
||||
description="删除当前会话的全部历史记录",
|
||||
|
@ -41,9 +40,9 @@ class DelCommand(AbstractCommandNode):
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class DelAllCommand(AbstractCommandNode):
|
||||
class DelAllCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
session_name = ctx.session_name
|
||||
reply = []
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="delhst",
|
||||
description="删除指定会话的所有历史记录",
|
||||
|
@ -9,9 +9,9 @@ from ..aamgr import AbstractCommandNode, Context
|
|||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class DelHistoryCommand(AbstractCommandNode):
|
||||
class DelHistoryCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
import pkg.utils.context
|
||||
params = ctx.params
|
||||
|
@ -31,7 +31,7 @@ class DelHistoryCommand(AbstractCommandNode):
|
|||
return True, reply
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=DelHistoryCommand,
|
||||
name="all",
|
||||
description="删除所有会话的全部历史记录",
|
||||
|
@ -39,9 +39,9 @@ class DelHistoryCommand(AbstractCommandNode):
|
|||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class DelAllHistoryCommand(AbstractCommandNode):
|
||||
class DelAllHistoryCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.utils.context
|
||||
reply = []
|
||||
pkg.utils.context.get_database_manager().delete_all_session_history()
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
import datetime
|
||||
|
||||
from .. import aamgr
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="last",
|
||||
description="切换前一次对话",
|
||||
|
@ -10,9 +11,9 @@ import datetime
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class LastCommand(AbstractCommandNode):
|
||||
class LastCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
session_name = ctx.session_name
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from .. import aamgr
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name='list',
|
||||
description='列出当前会话的所有历史记录',
|
||||
|
@ -11,9 +12,9 @@ import json
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class ListCommand(AbstractCommandNode):
|
||||
class ListCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
session_name = ctx.session_name
|
||||
params = ctx.params
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
import datetime
|
||||
|
||||
from .. import aamgr
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="next",
|
||||
description="切换后一次对话",
|
||||
|
@ -10,9 +11,9 @@ import datetime
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class NextCommand(AbstractCommandNode):
|
||||
class NextCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
session_name = ctx.session_name
|
||||
reply = []
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
import datetime
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="prompt",
|
||||
description="获取当前会话的前文",
|
||||
|
@ -10,9 +9,9 @@ import datetime
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class PromptCommand(AbstractCommandNode):
|
||||
class PromptCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
session_name = ctx.session_name
|
||||
params = ctx.params
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
import datetime
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="resend",
|
||||
description="重新获取上一次问题的回复",
|
||||
|
@ -10,20 +9,22 @@ import datetime
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class ResendCommand(AbstractCommandNode):
|
||||
class ResendCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
import pkg.openai.session
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
from ....openai import session as openai_session
|
||||
from ....utils import context
|
||||
from ....qqbot import message
|
||||
import config
|
||||
session_name = ctx.session_name
|
||||
reply = []
|
||||
|
||||
session = pkg.openai.session.get_session(session_name)
|
||||
session = openai_session.get_session(session_name)
|
||||
to_send = session.undo()
|
||||
|
||||
mgr = pkg.utils.context.get_qqbot_manager()
|
||||
mgr = context.get_qqbot_manager()
|
||||
|
||||
reply = pkg.qqbot.message.process_normal_message(to_send, mgr, config,
|
||||
reply = message.process_normal_message(to_send, mgr, config,
|
||||
ctx.launcher_type, ctx.launcher_id,
|
||||
ctx.sender_id)
|
||||
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
import tips as tips_custom
|
||||
|
||||
import pkg.openai.session
|
||||
import pkg.utils.context
|
||||
from .. import aamgr
|
||||
from ....openai import session
|
||||
from ....utils import context
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name='reset',
|
||||
description='重置当前会话',
|
||||
|
@ -13,21 +13,21 @@ import pkg.utils.context
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class ResetCommand(AbstractCommandNode):
|
||||
class ResetCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
params = ctx.params
|
||||
session_name = ctx.session_name
|
||||
|
||||
reply = ""
|
||||
|
||||
if len(params) == 0:
|
||||
pkg.openai.session.get_session(session_name).reset(explicit=True)
|
||||
session.get_session(session_name).reset(explicit=True)
|
||||
reply = [tips_custom.command_reset_message]
|
||||
else:
|
||||
try:
|
||||
import pkg.openai.dprompt as dprompt
|
||||
pkg.openai.session.get_session(session_name).reset(explicit=True, use_prompt=params[0])
|
||||
session.get_session(session_name).reset(explicit=True, use_prompt=params[0])
|
||||
reply = [tips_custom.command_reset_name_message+"{}".format(dprompt.mode_inst().get_full_name(params[0]))]
|
||||
except Exception as e:
|
||||
reply = ["[bot]会话重置失败:{}".format(e)]
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
import json
|
||||
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
def config_operation(cmd, params):
|
||||
reply = []
|
||||
|
@ -85,7 +86,7 @@ def config_operation(cmd, params):
|
|||
return reply
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="cfg",
|
||||
description="配置项管理",
|
||||
|
@ -93,8 +94,8 @@ def config_operation(cmd, params):
|
|||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class CfgCommand(AbstractCommandNode):
|
||||
class CfgCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
return True, config_operation(ctx.command, ctx.params)
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
from ..aamgr import AbstractCommandNode, Context, __command_list__
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="cmd",
|
||||
description="显示指令列表",
|
||||
|
@ -9,10 +9,10 @@ from ..aamgr import AbstractCommandNode, Context, __command_list__
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class CmdCommand(AbstractCommandNode):
|
||||
class CmdCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
command_list = __command_list__
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
command_list = aamgr.__command_list__
|
||||
|
||||
reply = []
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="help",
|
||||
description="显示自定义的帮助信息",
|
||||
|
@ -9,9 +9,9 @@ from ..aamgr import AbstractCommandNode, Context
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class HelpCommand(AbstractCommandNode):
|
||||
class HelpCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import tips
|
||||
reply = ["[bot] "+tips.help_message + "\n请输入 !cmd 查看指令列表"]
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
import threading
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="reload",
|
||||
description="执行热重载",
|
||||
|
@ -9,9 +11,9 @@ import threading
|
|||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class ReloadCommand(AbstractCommandNode):
|
||||
class ReloadCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
reply = []
|
||||
|
||||
import pkg.utils.reloader
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
from .. import aamgr
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="update",
|
||||
description="更新程序",
|
||||
|
@ -11,9 +12,9 @@ import traceback
|
|||
aliases=[],
|
||||
privilege=2
|
||||
)
|
||||
class UpdateCommand(AbstractCommandNode):
|
||||
class UpdateCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
reply = []
|
||||
import pkg.utils.updater
|
||||
import pkg.utils.reloader
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
import logging
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="usage",
|
||||
description="获取使用情况",
|
||||
|
@ -10,9 +9,9 @@ import logging
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class UsageCommand(AbstractCommandNode):
|
||||
class UsageCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
import config
|
||||
import pkg.utils.credit as credit
|
||||
import pkg.utils.context
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from ..aamgr import AbstractCommandNode, Context
|
||||
from .. import aamgr
|
||||
|
||||
|
||||
@AbstractCommandNode.register(
|
||||
@aamgr.AbstractCommandNode.register(
|
||||
parent=None,
|
||||
name="version",
|
||||
description="查看版本信息",
|
||||
|
@ -9,9 +9,9 @@ from ..aamgr import AbstractCommandNode, Context
|
|||
aliases=[],
|
||||
privilege=1
|
||||
)
|
||||
class VersionCommand(AbstractCommandNode):
|
||||
class VersionCommand(aamgr.AbstractCommandNode):
|
||||
@classmethod
|
||||
def process(cls, ctx: Context) -> tuple[bool, list]:
|
||||
def process(cls, ctx: aamgr.Context) -> tuple[bool, list]:
|
||||
reply = []
|
||||
import pkg.utils.updater
|
||||
|
||||
|
|
|
@ -1,23 +1,7 @@
|
|||
# 指令处理模块
|
||||
import logging
|
||||
import json
|
||||
import datetime
|
||||
import os
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
import pkg.openai.session
|
||||
import pkg.openai.manager
|
||||
import pkg.utils.reloader
|
||||
import pkg.utils.updater
|
||||
import pkg.utils.context
|
||||
import pkg.qqbot.message
|
||||
import pkg.utils.credit as credit
|
||||
# import pkg.qqbot.cmds.model as cmdmodel
|
||||
import pkg.qqbot.cmds.aamgr as cmdmgr
|
||||
|
||||
from mirai import Image
|
||||
|
||||
from ..qqbot.cmds import aamgr as cmdmgr
|
||||
|
||||
|
||||
def process_command(session_name: str, text_message: str, mgr, config,
|
||||
|
|
|
@ -1,32 +1,25 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
|
||||
import logging
|
||||
|
||||
from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
|
||||
FriendMessage, Image, MessageChain, Plain
|
||||
from func_timeout import func_set_timeout
|
||||
import func_timeout
|
||||
|
||||
import pkg.openai.session
|
||||
import pkg.openai.manager
|
||||
from func_timeout import FunctionTimedOut
|
||||
import logging
|
||||
from ..openai import session as openai_session
|
||||
|
||||
import pkg.qqbot.filter
|
||||
import pkg.qqbot.process as processor
|
||||
import pkg.utils.context
|
||||
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.plugin.models as plugin_models
|
||||
from ..qqbot import filter as qqbot_filter
|
||||
from ..qqbot import process as processor
|
||||
from ..utils import context
|
||||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
import tips as tips_custom
|
||||
|
||||
import pkg.qqbot.adapter as msadapter
|
||||
from ..qqbot import adapter as msadapter
|
||||
|
||||
|
||||
# 检查消息是否符合泛响应匹配机制
|
||||
def check_response_rule(group_id:int, text: str):
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
|
||||
rules = config.response_rules
|
||||
|
||||
|
@ -55,7 +48,7 @@ def check_response_rule(group_id:int, text: str):
|
|||
|
||||
|
||||
def response_at(group_id: int):
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
|
||||
use_response_rule = config.response_rules
|
||||
|
||||
|
@ -73,7 +66,7 @@ def response_at(group_id: int):
|
|||
|
||||
|
||||
def random_responding(group_id):
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
|
||||
use_response_rule = config.response_rules
|
||||
|
||||
|
@ -130,10 +123,10 @@ class QQBotManager:
|
|||
self.adapter = NakuruProjectAdapter(config.nakuru_config)
|
||||
self.bot_account_id = self.adapter.bot_account_id
|
||||
else:
|
||||
self.adapter = pkg.utils.context.get_qqbot_manager().adapter
|
||||
self.bot_account_id = pkg.utils.context.get_qqbot_manager().bot_account_id
|
||||
self.adapter = context.get_qqbot_manager().adapter
|
||||
self.bot_account_id = context.get_qqbot_manager().bot_account_id
|
||||
|
||||
pkg.utils.context.set_qqbot_manager(self)
|
||||
context.set_qqbot_manager(self)
|
||||
|
||||
# 注册诸事件
|
||||
# Caution: 注册新的事件处理器之后,请务必在unsubscribe_all中编写相应的取消订阅代码
|
||||
|
@ -154,7 +147,7 @@ class QQBotManager:
|
|||
|
||||
self.on_person_message(event)
|
||||
|
||||
pkg.utils.context.get_thread_ctl().submit_user_task(
|
||||
context.get_thread_ctl().submit_user_task(
|
||||
friend_message_handler,
|
||||
)
|
||||
self.adapter.register_listener(
|
||||
|
@ -179,7 +172,7 @@ class QQBotManager:
|
|||
|
||||
self.on_person_message(event)
|
||||
|
||||
pkg.utils.context.get_thread_ctl().submit_user_task(
|
||||
context.get_thread_ctl().submit_user_task(
|
||||
stranger_message_handler,
|
||||
)
|
||||
# nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件
|
||||
|
@ -206,7 +199,7 @@ class QQBotManager:
|
|||
|
||||
self.on_group_message(event)
|
||||
|
||||
pkg.utils.context.get_thread_ctl().submit_user_task(
|
||||
context.get_thread_ctl().submit_user_task(
|
||||
group_message_handler,
|
||||
event
|
||||
)
|
||||
|
@ -250,22 +243,22 @@ class QQBotManager:
|
|||
if hasattr(banlist, "enable_group"):
|
||||
self.enable_group = banlist.enable_group
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
if os.path.exists("sensitive.json") \
|
||||
and config.sensitive_word_filter is not None \
|
||||
and config.sensitive_word_filter:
|
||||
with open("sensitive.json", "r", encoding="utf-8") as f:
|
||||
sensitive_json = json.load(f)
|
||||
self.reply_filter = pkg.qqbot.filter.ReplyFilter(
|
||||
self.reply_filter = qqbot_filter.ReplyFilter(
|
||||
sensitive_words=sensitive_json['words'],
|
||||
mask=sensitive_json['mask'] if 'mask' in sensitive_json else '*',
|
||||
mask_word=sensitive_json['mask_word'] if 'mask_word' in sensitive_json else ''
|
||||
)
|
||||
else:
|
||||
self.reply_filter = pkg.qqbot.filter.ReplyFilter([])
|
||||
self.reply_filter = qqbot_filter.ReplyFilter([])
|
||||
|
||||
def send(self, event, msg, check_quote=True, check_at_sender=True):
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
|
||||
if check_at_sender and config.at_sender:
|
||||
msg.insert(
|
||||
|
@ -306,7 +299,7 @@ class QQBotManager:
|
|||
for i in range(self.retry):
|
||||
try:
|
||||
|
||||
@func_set_timeout(config.process_message_timeout)
|
||||
@func_timeout.func_set_timeout(config.process_message_timeout)
|
||||
def time_ctrl_wrapper():
|
||||
reply = processor.process_message('person', event.sender.id, str(event.message_chain),
|
||||
event.message_chain,
|
||||
|
@ -315,16 +308,16 @@ class QQBotManager:
|
|||
|
||||
reply = time_ctrl_wrapper()
|
||||
break
|
||||
except FunctionTimedOut:
|
||||
except func_timeout.FunctionTimedOut:
|
||||
logging.warning("person_{}: 超时,重试中({})".format(event.sender.id, i))
|
||||
pkg.openai.session.get_session('person_{}'.format(event.sender.id)).release_response_lock()
|
||||
if "person_{}".format(event.sender.id) in pkg.qqbot.process.processing:
|
||||
pkg.qqbot.process.processing.remove('person_{}'.format(event.sender.id))
|
||||
openai_session.get_session('person_{}'.format(event.sender.id)).release_response_lock()
|
||||
if "person_{}".format(event.sender.id) in processor.processing:
|
||||
processor.processing.remove('person_{}'.format(event.sender.id))
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
if failed == self.retry:
|
||||
pkg.openai.session.get_session('person_{}'.format(event.sender.id)).release_response_lock()
|
||||
openai_session.get_session('person_{}'.format(event.sender.id)).release_response_lock()
|
||||
self.notify_admin("{} 请求超时".format("person_{}".format(event.sender.id)))
|
||||
reply = [tips_custom.reply_message]
|
||||
|
||||
|
@ -344,7 +337,7 @@ class QQBotManager:
|
|||
failed = 0
|
||||
for i in range(self.retry):
|
||||
try:
|
||||
@func_set_timeout(config.process_message_timeout)
|
||||
@func_timeout.func_set_timeout(config.process_message_timeout)
|
||||
def time_ctrl_wrapper():
|
||||
replys = processor.process_message('group', event.group.id,
|
||||
str(event.message_chain).strip() if text is None else text,
|
||||
|
@ -354,16 +347,16 @@ class QQBotManager:
|
|||
|
||||
replys = time_ctrl_wrapper()
|
||||
break
|
||||
except FunctionTimedOut:
|
||||
except func_timeout.FunctionTimedOut:
|
||||
logging.warning("group_{}: 超时,重试中({})".format(event.group.id, i))
|
||||
pkg.openai.session.get_session('group_{}'.format(event.group.id)).release_response_lock()
|
||||
if "group_{}".format(event.group.id) in pkg.qqbot.process.processing:
|
||||
pkg.qqbot.process.processing.remove('group_{}'.format(event.group.id))
|
||||
openai_session.get_session('group_{}'.format(event.group.id)).release_response_lock()
|
||||
if "group_{}".format(event.group.id) in processor.processing:
|
||||
processor.processing.remove('group_{}'.format(event.group.id))
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
if failed == self.retry:
|
||||
pkg.openai.session.get_session('group_{}'.format(event.group.id)).release_response_lock()
|
||||
openai_session.get_session('group_{}'.format(event.group.id)).release_response_lock()
|
||||
self.notify_admin("{} 请求超时".format("group_{}".format(event.group.id)))
|
||||
replys = [tips_custom.replys_message]
|
||||
|
||||
|
@ -392,7 +385,7 @@ class QQBotManager:
|
|||
|
||||
# 通知系统管理员
|
||||
def notify_admin(self, message: str):
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
if config.admin_qq != 0 and config.admin_qq != []:
|
||||
logging.info("通知管理员:{}".format(message))
|
||||
if type(config.admin_qq) == int:
|
||||
|
@ -410,7 +403,7 @@ class QQBotManager:
|
|||
)
|
||||
|
||||
def notify_admin_message_chain(self, message):
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
if config.admin_qq != 0 and config.admin_qq != []:
|
||||
logging.info("通知管理员:{}".format(message))
|
||||
if type(config.admin_qq) == int:
|
||||
|
|
|
@ -1,19 +1,20 @@
|
|||
# 普通消息处理模块
|
||||
import logging
|
||||
import openai
|
||||
import pkg.utils.context
|
||||
import pkg.openai.session
|
||||
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.plugin.models as plugin_models
|
||||
import pkg.qqbot.blob as blob
|
||||
import openai
|
||||
|
||||
from ..utils import context
|
||||
from ..openai import session as openai_session
|
||||
|
||||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
import tips as tips_custom
|
||||
|
||||
|
||||
def handle_exception(notify_admin: str = "", set_reply: str = "") -> list:
|
||||
"""处理异常,当notify_admin不为空时,会通知管理员,返回通知用户的消息"""
|
||||
import config
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin(notify_admin)
|
||||
context.get_qqbot_manager().notify_admin(notify_admin)
|
||||
if config.hide_exce_info_to_user:
|
||||
return [tips_custom.alter_tip_message] if tips_custom.alter_tip_message else []
|
||||
else:
|
||||
|
@ -26,7 +27,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str,
|
|||
logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + (
|
||||
"..." if len(text_message) > 20 else "")))
|
||||
|
||||
session = pkg.openai.session.get_session(session_name)
|
||||
session = openai_session.get_session(session_name)
|
||||
|
||||
unexpected_exception_times = 0
|
||||
|
||||
|
@ -54,7 +55,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str,
|
|||
"funcs_called": funcs,
|
||||
}
|
||||
|
||||
event = pkg.plugin.host.emit(plugin_models.NormalMessageResponded, **args)
|
||||
event = plugin_host.emit(plugin_models.NormalMessageResponded, **args)
|
||||
|
||||
if event.get_return_value("prefix") is not None:
|
||||
prefix = event.get_return_value("prefix")
|
||||
|
@ -78,29 +79,29 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str,
|
|||
|
||||
if 'message' in e.error and e.error['message'].__contains__('You exceeded your current quota'):
|
||||
# 尝试切换api-key
|
||||
current_key_name = pkg.utils.context.get_openai_manager().key_mgr.get_key_name(
|
||||
pkg.utils.context.get_openai_manager().key_mgr.using_key
|
||||
current_key_name = context.get_openai_manager().key_mgr.get_key_name(
|
||||
context.get_openai_manager().key_mgr.using_key
|
||||
)
|
||||
pkg.utils.context.get_openai_manager().key_mgr.set_current_exceeded()
|
||||
context.get_openai_manager().key_mgr.set_current_exceeded()
|
||||
|
||||
# 触发插件事件
|
||||
args = {
|
||||
'key_name': current_key_name,
|
||||
'usage': pkg.utils.context.get_openai_manager().audit_mgr
|
||||
.get_usage(pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()),
|
||||
'exceeded_keys': pkg.utils.context.get_openai_manager().key_mgr.exceeded,
|
||||
'usage': context.get_openai_manager().audit_mgr
|
||||
.get_usage(context.get_openai_manager().key_mgr.get_using_key_md5()),
|
||||
'exceeded_keys': context.get_openai_manager().key_mgr.exceeded,
|
||||
}
|
||||
event = plugin_host.emit(plugin_models.KeyExceeded, **args)
|
||||
|
||||
if not event.is_prevented_default():
|
||||
switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch()
|
||||
switched, name = context.get_openai_manager().key_mgr.auto_switch()
|
||||
|
||||
if not switched:
|
||||
reply = handle_exception(
|
||||
"api-key调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key;如果你认为这是误判,请尝试重启程序。".format(
|
||||
current_key_name), "[bot]err:API调用额度超额,请联系管理员,或等待修复")
|
||||
else:
|
||||
openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key()
|
||||
openai.api_key = context.get_openai_manager().key_mgr.get_using_key()
|
||||
mgr.notify_admin("api-key调用额度超限({}),接口报错,已切换到{}".format(current_key_name, name))
|
||||
reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"]
|
||||
continue
|
||||
|
|
|
@ -5,28 +5,22 @@ import time
|
|||
import mirai
|
||||
import logging
|
||||
|
||||
from mirai import MessageChain, Plain
|
||||
|
||||
# 这里不使用动态引入config
|
||||
# 因为在这里动态引入会卡死程序
|
||||
# 而此模块静态引用config与动态引入的表现一致
|
||||
# 已弃用,由于超时时间现已动态使用
|
||||
# import config as config_init_import
|
||||
|
||||
import pkg.openai.session
|
||||
import pkg.openai.manager
|
||||
import pkg.utils.reloader
|
||||
import pkg.utils.updater
|
||||
import pkg.utils.context
|
||||
import pkg.qqbot.message
|
||||
import pkg.qqbot.command
|
||||
import pkg.qqbot.ratelimit as ratelimit
|
||||
from ..qqbot import ratelimit
|
||||
from ..qqbot import command, message
|
||||
from ..openai import session as openai_session
|
||||
from ..utils import context
|
||||
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.plugin.models as plugin_models
|
||||
import pkg.qqbot.ignore as ignore
|
||||
import pkg.qqbot.banlist as banlist
|
||||
import pkg.qqbot.blob as blob
|
||||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
from ..qqbot import ignore
|
||||
from ..qqbot import banlist
|
||||
from ..qqbot import blob
|
||||
import tips as tips_custom
|
||||
|
||||
processing = []
|
||||
|
@ -41,11 +35,11 @@ def is_admin(qq: int) -> bool:
|
|||
return qq == config.admin_qq
|
||||
|
||||
|
||||
def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: MessageChain,
|
||||
sender_id: int) -> MessageChain:
|
||||
def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: mirai.MessageChain,
|
||||
sender_id: int) -> mirai.MessageChain:
|
||||
global processing
|
||||
|
||||
mgr = pkg.utils.context.get_qqbot_manager()
|
||||
mgr = context.get_qqbot_manager()
|
||||
|
||||
reply = []
|
||||
session_name = "{}_{}".format(launcher_type, launcher_id)
|
||||
|
@ -62,7 +56,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
|||
import config
|
||||
|
||||
if not config.wait_last_done and session_name in processing:
|
||||
return MessageChain([Plain(tips_custom.message_drop_tip)])
|
||||
return mirai.MessageChain([mirai.Plain(tips_custom.message_drop_tip)])
|
||||
|
||||
# 检查是否被禁言
|
||||
if launcher_type == 'group':
|
||||
|
@ -74,9 +68,9 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
|||
import config
|
||||
if config.income_msg_check:
|
||||
if mgr.reply_filter.is_illegal(text_message):
|
||||
return MessageChain(Plain("[bot] 消息中存在不合适的内容, 请更换措辞"))
|
||||
return mirai.MessageChain(mirai.Plain("[bot] 消息中存在不合适的内容, 请更换措辞"))
|
||||
|
||||
pkg.openai.session.get_session(session_name).acquire_response_lock()
|
||||
openai_session.get_session(session_name).acquire_response_lock()
|
||||
|
||||
text_message = text_message.strip()
|
||||
|
||||
|
@ -87,7 +81,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
|||
# 处理消息
|
||||
try:
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
|
||||
processing.append(session_name)
|
||||
try:
|
||||
|
@ -114,7 +108,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
|||
reply = event.get_return_value("reply")
|
||||
|
||||
if not event.is_prevented_default():
|
||||
reply = pkg.qqbot.command.process_command(session_name, text_message,
|
||||
reply = command.process_command(session_name, text_message,
|
||||
mgr, config, launcher_type, launcher_id, sender_id, is_admin(sender_id))
|
||||
|
||||
else: # 消息
|
||||
|
@ -124,7 +118,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
|||
if ratelimit.is_reach_limit(session_name):
|
||||
logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message))
|
||||
|
||||
return MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else []
|
||||
return mirai.MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else []
|
||||
|
||||
before = time.time()
|
||||
# 触发插件事件
|
||||
|
@ -146,7 +140,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
|||
reply = event.get_return_value("reply")
|
||||
|
||||
if not event.is_prevented_default():
|
||||
reply = pkg.qqbot.message.process_normal_message(text_message,
|
||||
reply = message.process_normal_message(text_message,
|
||||
mgr, config, launcher_type, launcher_id, sender_id)
|
||||
|
||||
# 限速等待时间
|
||||
|
@ -170,7 +164,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
|||
finally:
|
||||
processing.remove(session_name)
|
||||
finally:
|
||||
pkg.openai.session.get_session(session_name).release_response_lock()
|
||||
openai_session.get_session(session_name).release_response_lock()
|
||||
|
||||
# 检查延迟时间
|
||||
if config.force_delay_range[1] == 0:
|
||||
|
@ -191,4 +185,4 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
|||
logging.info("[风控] 强制延迟{:.2f}秒(如需关闭,请到config.py修改force_delay_range字段)".format(delay_time))
|
||||
time.sleep(delay_time)
|
||||
|
||||
return MessageChain(reply)
|
||||
return mirai.MessageChain(reply)
|
||||
|
|
|
@ -1,19 +1,18 @@
|
|||
import mirai
|
||||
|
||||
from ..adapter import MessageSourceAdapter, MessageConverter, EventConverter
|
||||
import nakuru
|
||||
import nakuru.entities.components as nkc
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
import traceback
|
||||
import logging
|
||||
import json
|
||||
|
||||
from pkg.qqbot.blob import Forward, ForwardMessageNode, ForwardMessageDiaplay
|
||||
import mirai
|
||||
|
||||
import nakuru
|
||||
import nakuru.entities.components as nkc
|
||||
|
||||
from .. import adapter as adapter_model
|
||||
from ...qqbot import blob
|
||||
|
||||
|
||||
class NakuruProjectMessageConverter(MessageConverter):
|
||||
class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
"""消息转换器"""
|
||||
@staticmethod
|
||||
def yiri2target(message_chain: mirai.MessageChain) -> list:
|
||||
|
@ -49,7 +48,7 @@ class NakuruProjectMessageConverter(MessageConverter):
|
|||
nakuru_msg_list.append(nkc.Record.fromURL(component.url))
|
||||
elif component.path is not None:
|
||||
nakuru_msg_list.append(nkc.Record.fromFileSystem(component.path))
|
||||
elif type(component) is Forward:
|
||||
elif type(component) is blob.Forward:
|
||||
# 转发消息
|
||||
yiri_forward_node_list = component.node_list
|
||||
nakuru_forward_node_list = []
|
||||
|
@ -102,7 +101,7 @@ class NakuruProjectMessageConverter(MessageConverter):
|
|||
return chain
|
||||
|
||||
|
||||
class NakuruProjectEventConverter(EventConverter):
|
||||
class NakuruProjectEventConverter(adapter_model.EventConverter):
|
||||
"""事件转换器"""
|
||||
@staticmethod
|
||||
def yiri2target(event: typing.Type[mirai.Event]):
|
||||
|
@ -157,7 +156,7 @@ class NakuruProjectEventConverter(EventConverter):
|
|||
raise Exception("未支持转换的事件类型: " + str(event))
|
||||
|
||||
|
||||
class NakuruProjectAdapter(MessageSourceAdapter):
|
||||
class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
|
||||
"""nakuru-project适配器"""
|
||||
bot: nakuru.CQHTTP
|
||||
bot_account_id: int
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
from ..adapter import MessageSourceAdapter
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
import mirai.models.bus
|
||||
from mirai.bot import MiraiRunner
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
from .. import adapter as adapter_model
|
||||
|
||||
|
||||
class YiriMiraiAdapter(MessageSourceAdapter):
|
||||
class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
|
||||
"""YiriMirai适配器"""
|
||||
bot: mirai.Mirai
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import threading
|
||||
from pkg.utils import ThreadCtl
|
||||
from . import threadctl
|
||||
|
||||
|
||||
context = {
|
||||
|
@ -87,8 +87,8 @@ def set_thread_ctl(inst):
|
|||
context_lock.release()
|
||||
|
||||
|
||||
def get_thread_ctl() -> ThreadCtl:
|
||||
def get_thread_ctl() -> threadctl.ThreadCtl:
|
||||
context_lock.acquire()
|
||||
t: ThreadCtl = context['pool_ctl']
|
||||
t: threadctl.ThreadCtl = context['pool_ctl']
|
||||
context_lock.release()
|
||||
return t
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from pip._internal import main as pipmain
|
||||
|
||||
import pkg.utils.log as log
|
||||
from . import log
|
||||
|
||||
|
||||
def install(package):
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
import logging
|
||||
import threading
|
||||
|
||||
import importlib
|
||||
import pkgutil
|
||||
import pkg.utils.context as context
|
||||
import pkg.plugin.host
|
||||
|
||||
from . import context
|
||||
from ..plugin import host as plugin_host
|
||||
|
||||
|
||||
def walk(module, prefix='', path_prefix=''):
|
||||
|
@ -15,7 +14,7 @@ def walk(module, prefix='', path_prefix=''):
|
|||
walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.', path_prefix + item.name + '/')
|
||||
else:
|
||||
logging.info('reload module: {}, path: {}'.format(prefix + item.name, path_prefix + item.name + '.py'))
|
||||
pkg.plugin.host.__current_module_path__ = "plugins/" + path_prefix + item.name + '.py'
|
||||
plugin_host.__current_module_path__ = "plugins/" + path_prefix + item.name + '.py'
|
||||
importlib.reload(__import__(module.__name__ + '.' + item.name, fromlist=['']))
|
||||
|
||||
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import logging
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import re
|
||||
import os
|
||||
import config
|
||||
import traceback
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
text_render_font: ImageFont = None
|
||||
|
||||
if config.blob_message_strategy == "image": # 仅在启用了image时才加载字体
|
||||
|
|
|
@ -3,10 +3,9 @@ import logging
|
|||
import os.path
|
||||
|
||||
import requests
|
||||
import json
|
||||
|
||||
import pkg.utils.constants
|
||||
import pkg.utils.network as network
|
||||
from . import constants
|
||||
from . import network
|
||||
|
||||
|
||||
def check_dulwich_closure():
|
||||
|
@ -70,7 +69,7 @@ def get_release_list() -> list:
|
|||
|
||||
def get_current_tag() -> str:
|
||||
"""获取当前tag"""
|
||||
current_tag = pkg.utils.constants.semantic_version
|
||||
current_tag = constants.semantic_version
|
||||
if os.path.exists("current_tag"):
|
||||
with open("current_tag", "r") as f:
|
||||
current_tag = f.read()
|
||||
|
|
Loading…
Reference in New Issue
Block a user