mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
feat: 内容过滤器的可扩展性
This commit is contained in:
parent
7f554fd862
commit
22cb8a6a06
|
@ -30,7 +30,7 @@ def operator_class(
|
|||
parent_class (typing.Type[CommandOperator], optional): 父节点,若为None则为顶级命令. Defaults to None.
|
||||
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 注册后的命令类
|
||||
typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 装饰器
|
||||
"""
|
||||
|
||||
def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
|
||||
|
|
|
@ -7,7 +7,7 @@ from ...core import app
|
|||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from . import filter, entities as filter_entities
|
||||
from . import filter as filter_model, entities as filter_entities
|
||||
from .filters import cntignore, banwords, baiduexamine
|
||||
|
||||
|
||||
|
@ -16,20 +16,29 @@ from .filters import cntignore, banwords, baiduexamine
|
|||
class ContentFilterStage(stage.PipelineStage):
|
||||
"""内容过滤阶段"""
|
||||
|
||||
filter_chain: list[filter.ContentFilter]
|
||||
filter_chain: list[filter_model.ContentFilter]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.filter_chain = []
|
||||
super().__init__(ap)
|
||||
|
||||
async def initialize(self):
|
||||
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
|
||||
|
||||
filters_required = [
|
||||
"ContentIgnore"
|
||||
]
|
||||
|
||||
if self.ap.pipeline_cfg.data['check-sensitive-words']:
|
||||
self.filter_chain.append(banwords.BanWordFilter(self.ap))
|
||||
|
||||
filters_required.append("BanWordFilter")
|
||||
|
||||
if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
|
||||
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
|
||||
filters_required.append("BaiduCloudExamine")
|
||||
|
||||
for filter in filter_model.preregistered_filters:
|
||||
if filter.name in filters_required:
|
||||
self.filter_chain.append(
|
||||
filter(self.ap)
|
||||
)
|
||||
|
||||
for filter in self.filter_chain:
|
||||
await filter.initialize()
|
||||
|
|
|
@ -1,12 +1,42 @@
|
|||
# 内容过滤器的抽象类
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
preregistered_filters: list[typing.Type[ContentFilter]] = []
|
||||
|
||||
|
||||
def filter_class(
|
||||
name: str
|
||||
) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]:
|
||||
"""内容过滤器类装饰器
|
||||
|
||||
Args:
|
||||
name (str): 过滤器名称
|
||||
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器
|
||||
"""
|
||||
def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]:
|
||||
assert issubclass(cls, ContentFilter)
|
||||
|
||||
cls.name = name
|
||||
|
||||
preregistered_filters.append(cls)
|
||||
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class ContentFilter(metaclass=abc.ABCMeta):
|
||||
"""内容过滤器抽象类"""
|
||||
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v
|
|||
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
|
||||
|
||||
@filter_model.filter_class("BaiduCloudExamine")
|
||||
class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
"""百度云内容审核"""
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ from .. import entities
|
|||
from ....config import manager as cfg_mgr
|
||||
|
||||
|
||||
@filter_model.filter_class("BanWordFilter")
|
||||
class BanWordFilter(filter_model.ContentFilter):
|
||||
"""根据内容禁言"""
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ from .. import entities
|
|||
from .. import filter as filter_model
|
||||
|
||||
|
||||
@filter_model.filter_class("ContentIgnore")
|
||||
class ContentIgnore(filter_model.ContentFilter):
|
||||
"""根据内容忽略消息"""
|
||||
|
||||
|
|
|
@ -163,25 +163,6 @@ class PlatformManager:
|
|||
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False
|
||||
)
|
||||
|
||||
# 通知系统管理员
|
||||
# TODO delete
|
||||
# async def notify_admin(self, message: str):
|
||||
# await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))]))
|
||||
|
||||
# async def notify_admin_message_chain(self, message: mirai.MessageChain):
|
||||
# if self.ap.system_cfg.data['admin-sessions'] != []:
|
||||
|
||||
# admin_list = []
|
||||
# for admin in self.ap.system_cfg.data['admin-sessions']:
|
||||
# admin_list.append(admin)
|
||||
|
||||
# for adm in admin_list:
|
||||
# self.adapter.send_message(
|
||||
# adm.split("_")[0],
|
||||
# adm.split("_")[1],
|
||||
# message
|
||||
# )
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
tasks = []
|
||||
|
|
|
@ -24,6 +24,8 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
|||
msg_list = message_chain.__root__
|
||||
elif type(message_chain) is list:
|
||||
msg_list = message_chain
|
||||
elif type(message_chain) is str:
|
||||
msg_list = [mirai.Plain(message_chain)]
|
||||
else:
|
||||
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
|
||||
|
||||
|
|
|
@ -89,6 +89,8 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
|||
msg_list = message_chain.__root__
|
||||
elif type(message_chain) is list:
|
||||
msg_list = message_chain
|
||||
elif type(message_chain) is str:
|
||||
msg_list = [mirai.Plain(text=message_chain)]
|
||||
else:
|
||||
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user