feat: 内容过滤器的可扩展性

This commit is contained in:
RockChinQ 2024-03-08 20:22:06 +08:00
parent 7f554fd862
commit 22cb8a6a06
9 changed files with 53 additions and 26 deletions

View File

@ -30,7 +30,7 @@ def operator_class(
parent_class (typing.Type[CommandOperator], optional): 父节点若为None则为顶级命令. Defaults to None.
Returns:
typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 注册后的命令类
typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 装饰器
"""
def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:

View File

@ -7,7 +7,7 @@ from ...core import app
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from . import filter, entities as filter_entities
from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine
@ -16,20 +16,29 @@ from .filters import cntignore, banwords, baiduexamine
class ContentFilterStage(stage.PipelineStage):
"""内容过滤阶段"""
filter_chain: list[filter.ContentFilter]
filter_chain: list[filter_model.ContentFilter]
def __init__(self, ap: app.Application):
self.filter_chain = []
super().__init__(ap)
async def initialize(self):
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
filters_required = [
"ContentIgnore"
]
if self.ap.pipeline_cfg.data['check-sensitive-words']:
self.filter_chain.append(banwords.BanWordFilter(self.ap))
filters_required.append("BanWordFilter")
if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
filters_required.append("BaiduCloudExamine")
for filter in filter_model.preregistered_filters:
if filter.name in filters_required:
self.filter_chain.append(
filter(self.ap)
)
for filter in self.filter_chain:
await filter.initialize()

View File

@ -1,12 +1,42 @@
# 内容过滤器的抽象类
from __future__ import annotations
import abc
import typing
from ...core import app
from . import entities
preregistered_filters: list[typing.Type[ContentFilter]] = []
def filter_class(
name: str
) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]:
"""内容过滤器类装饰器
Args:
name (str): 过滤器名称
Returns:
typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器
"""
def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]:
assert issubclass(cls, ContentFilter)
cls.name = name
preregistered_filters.append(cls)
return cls
return decorator
class ContentFilter(metaclass=abc.ABCMeta):
"""内容过滤器抽象类"""
name: str
ap: app.Application

View File

@ -10,6 +10,7 @@ BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
@filter_model.filter_class("BaiduCloudExamine")
class BaiduCloudExamine(filter_model.ContentFilter):
"""百度云内容审核"""

View File

@ -6,6 +6,7 @@ from .. import entities
from ....config import manager as cfg_mgr
@filter_model.filter_class("BanWordFilter")
class BanWordFilter(filter_model.ContentFilter):
"""根据内容禁言"""

View File

@ -5,6 +5,7 @@ from .. import entities
from .. import filter as filter_model
@filter_model.filter_class("ContentIgnore")
class ContentIgnore(filter_model.ContentFilter):
"""根据内容忽略消息"""

View File

@ -163,25 +163,6 @@ class PlatformManager:
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False
)
# 通知系统管理员
# TODO delete
# async def notify_admin(self, message: str):
# await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))]))
# async def notify_admin_message_chain(self, message: mirai.MessageChain):
# if self.ap.system_cfg.data['admin-sessions'] != []:
# admin_list = []
# for admin in self.ap.system_cfg.data['admin-sessions']:
# admin_list.append(admin)
# for adm in admin_list:
# self.adapter.send_message(
# adm.split("_")[0],
# adm.split("_")[1],
# message
# )
async def run(self):
try:
tasks = []

View File

@ -24,6 +24,8 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
msg_list = message_chain.__root__
elif type(message_chain) is list:
msg_list = message_chain
elif type(message_chain) is str:
msg_list = [mirai.Plain(message_chain)]
else:
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))

View File

@ -89,6 +89,8 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
msg_list = message_chain.__root__
elif type(message_chain) is list:
msg_list = message_chain
elif type(message_chain) is str:
msg_list = [mirai.Plain(text=message_chain)]
else:
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))