Merge pull request #890 from RockChinQ/feat/more-platforms

Refactor: 移除 YiriMirai 组件
This commit is contained in:
Junyan Qin 2024-09-26 14:41:03 +08:00 committed by GitHub
commit ea6a0af5a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
43 changed files with 1576 additions and 388 deletions

View File

@ -8,10 +8,10 @@ body:
label: 消息平台适配器 label: 消息平台适配器
description: "连接QQ使用的框架" description: "连接QQ使用的框架"
options: options:
- yiri-miraiMirai
- Nakurugo-cqhttp - Nakurugo-cqhttp
- aiocqhttp使用 OneBot 协议接入的) - aiocqhttp使用 OneBot 协议接入的)
- qq-botpyQQ官方API - qq-botpyQQ官方API
- yiri-miraiMirai
validations: validations:
required: false required: false
- type: input - type: input

View File

@ -10,5 +10,4 @@ updates:
schedule: schedule:
interval: "weekly" interval: "weekly"
allow: allow:
- dependency-name: "yiri-mirai-rc"
- dependency-name: "openai" - dependency-name: "openai"

View File

@ -6,8 +6,11 @@
### PR 作者完成 ### PR 作者完成
*请在方括号间写`x`以打勾
- [ ] 阅读仓库[贡献指引](https://github.com/RockChinQ/QChatGPT/blob/master/CONTRIBUTING.md)了吗? - [ ] 阅读仓库[贡献指引](https://github.com/RockChinQ/QChatGPT/blob/master/CONTRIBUTING.md)了吗?
- [ ] 与项目所有者沟通过了吗? - [ ] 与项目所有者沟通过了吗?
- [ ] 我确定已自行测试所作的更改,确保功能符合预期。
### 项目所有者完成 ### 项目所有者完成

View File

@ -3,10 +3,10 @@ from __future__ import annotations
import typing import typing
import pydantic import pydantic
import mirai
from ..core import app, entities as core_entities from ..core import app, entities as core_entities
from . import errors, operator from . import errors, operator
from ..platform.types import message as platform_message
class CommandReturn(pydantic.BaseModel): class CommandReturn(pydantic.BaseModel):
@ -17,7 +17,7 @@ class CommandReturn(pydantic.BaseModel):
"""文本 """文本
""" """
image: typing.Optional[mirai.Image] = None image: typing.Optional[platform_message.Image] = None
"""弃用""" """弃用"""
image_url: typing.Optional[str] = None image_url: typing.Optional[str] = None

View File

@ -5,7 +5,6 @@ required_deps = {
"openai": "openai", "openai": "openai",
"anthropic": "anthropic", "anthropic": "anthropic",
"colorlog": "colorlog", "colorlog": "colorlog",
"mirai": "yiri-mirai-rc",
"aiocqhttp": "aiocqhttp", "aiocqhttp": "aiocqhttp",
"botpy": "qq-botpy", "botpy": "qq-botpy",
"PIL": "pillow", "PIL": "pillow",

View File

@ -6,13 +6,15 @@ import datetime
import asyncio import asyncio
import pydantic import pydantic
import mirai
from ..provider import entities as llm_entities from ..provider import entities as llm_entities
from ..provider.modelmgr import entities from ..provider.modelmgr import entities
from ..provider.sysprompt import entities as sysprompt_entities from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
from ..platform.types import events as platform_events
from ..platform.types import entities as platform_entities
class LauncherTypes(enum.Enum): class LauncherTypes(enum.Enum):
@ -40,10 +42,10 @@ class Query(pydantic.BaseModel):
sender_id: int sender_id: int
"""发送者IDplatform处理阶段设置""" """发送者IDplatform处理阶段设置"""
message_event: mirai.MessageEvent message_event: platform_events.MessageEvent
"""事件platform收到的原始事件""" """事件platform收到的原始事件"""
message_chain: mirai.MessageChain message_chain: platform_message.MessageChain
"""消息链platform收到的原始消息链""" """消息链platform收到的原始消息链"""
adapter: msadapter.MessageSourceAdapter adapter: msadapter.MessageSourceAdapter
@ -67,10 +69,10 @@ class Query(pydantic.BaseModel):
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器阶段设置""" """使用的函数,由前置处理器阶段设置"""
resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[mirai.MessageChain]] = [] resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] = []
"""由Process阶段生成的回复消息对象列表""" """由Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
"""回复消息链从resp_messages包装而得""" """回复消息链从resp_messages包装而得"""
# ======= 内部保留 ======= # ======= 内部保留 =======
@ -108,7 +110,7 @@ class Session(pydantic.BaseModel):
using_conversation: typing.Optional[Conversation] = None using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = [] conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)

View File

@ -1,9 +1,5 @@
from __future__ import annotations from __future__ import annotations
import mirai
import mirai.models
import mirai.models.message
from ...core import app from ...core import app
from .. import stage, entities, stagemgr from .. import stage, entities, stagemgr
@ -12,6 +8,9 @@ from ...config import manager as cfg_mgr
from . import filter as filter_model, entities as filter_entities from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine from .filters import cntignore, banwords, baiduexamine
from ...provider import entities as llm_entities from ...provider import entities as llm_entities
from ...platform.types import message as platform_message
from ...platform.types import events as platform_events
from ...platform.types import entities as platform_entities
@stage.stage_class('PostContentFilterStage') @stage.stage_class('PostContentFilterStage')
@ -89,8 +88,8 @@ class ContentFilterStage(stage.PipelineStage):
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个 elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement message = result.replacement
query.message_chain = mirai.MessageChain( query.message_chain = platform_message.MessageChain(
mirai.Plain(message) platform_message.Plain(message)
) )
return entities.StageProcessResult( return entities.StageProcessResult(
@ -148,7 +147,7 @@ class ContentFilterStage(stage.PipelineStage):
contain_non_text = False contain_non_text = False
text_components = [mirai.Plain, mirai.models.message.Source] text_components = [platform_message.Plain, platform_message.Source]
for me in query.message_chain: for me in query.message_chain:
if type(me) not in text_components: if type(me) not in text_components:

View File

@ -4,11 +4,11 @@ import asyncio
import typing import typing
import traceback import traceback
import mirai
from ..core import app, entities from ..core import app, entities
from . import entities as pipeline_entities from . import entities as pipeline_entities
from ..plugin import events from ..plugin import events
from ..platform.types import message as platform_message
class Controller: class Controller:
@ -73,11 +73,11 @@ class Controller:
# 处理str类型 # 处理str类型
if isinstance(result.user_notice, str): if isinstance(result.user_notice, str):
result.user_notice = mirai.MessageChain( result.user_notice = platform_message.MessageChain(
mirai.Plain(result.user_notice) platform_message.Plain(result.user_notice)
) )
elif isinstance(result.user_notice, list): elif isinstance(result.user_notice, list):
result.user_notice = mirai.MessageChain( result.user_notice = platform_message.MessageChain(
*result.user_notice *result.user_notice
) )

View File

@ -4,8 +4,7 @@ import enum
import typing import typing
import pydantic import pydantic
import mirai from ..platform.types import message as platform_message
import mirai.models.message as mirai_message
from ..core import entities from ..core import entities
@ -25,13 +24,9 @@ class StageProcessResult(pydantic.BaseModel):
new_query: entities.Query new_query: entities.Query
user_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = [] user_notice: typing.Optional[typing.Union[str, list[platform_message.MessageComponent], platform_message.MessageChain, None]] = []
"""只要设置了就会发送给用户""" """只要设置了就会发送给用户"""
# TODO delete
# admin_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
"""只要设置了就会发送给管理员"""
console_notice: typing.Optional[str] = '' console_notice: typing.Optional[str] = ''
"""只要设置了就会输出到控制台""" """只要设置了就会输出到控制台"""

View File

@ -3,7 +3,6 @@ import os
import traceback import traceback
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from mirai.models.message import MessageComponent, Plain, MessageChain
from ...core import app from ...core import app
from . import strategy from . import strategy
@ -11,6 +10,7 @@ from .strategies import image, forward
from .. import stage, entities, stagemgr from .. import stage, entities, stagemgr
from ...core import entities as core_entities from ...core import entities as core_entities
from ...config import manager as cfg_mgr from ...config import manager as cfg_mgr
from ...platform.types import message as platform_message
@stage.stage_class("LongTextProcessStage") @stage.stage_class("LongTextProcessStage")
@ -63,14 +63,14 @@ class LongTextProcessStage(stage.PipelineStage):
contains_non_plain = False contains_non_plain = False
for msg in query.resp_message_chain[-1]: for msg in query.resp_message_chain[-1]:
if not isinstance(msg, Plain): if not isinstance(msg, platform_message.Plain):
contains_non_plain = True contains_non_plain = True
break break
if contains_non_plain: if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。") self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']: elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query)) query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
return entities.StageProcessResult( return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,

View File

@ -2,15 +2,14 @@
from __future__ import annotations from __future__ import annotations
import typing import typing
from mirai.models import MessageChain import pydantic
from mirai.models.message import MessageComponent, ForwardMessageNode
from mirai.models.base import MiraiBaseModel
from .. import strategy as strategy_model from .. import strategy as strategy_model
from ....core import entities as core_entities from ....core import entities as core_entities
from ....platform.types import message as platform_message
class ForwardMessageDiaplay(MiraiBaseModel): class ForwardMessageDiaplay(pydantic.BaseModel):
title: str = "群聊的聊天记录" title: str = "群聊的聊天记录"
brief: str = "[聊天记录]" brief: str = "[聊天记录]"
source: str = "聊天记录" source: str = "聊天记录"
@ -18,13 +17,13 @@ class ForwardMessageDiaplay(MiraiBaseModel):
summary: str = "查看x条转发消息" summary: str = "查看x条转发消息"
class Forward(MessageComponent): class Forward(platform_message.MessageComponent):
"""合并转发。""" """合并转发。"""
type: str = "Forward" type: str = "Forward"
"""消息组件类型。""" """消息组件类型。"""
display: ForwardMessageDiaplay display: ForwardMessageDiaplay
"""显示信息""" """显示信息"""
node_list: typing.List[ForwardMessageNode] node_list: typing.List[platform_message.ForwardMessageNode]
"""转发消息节点列表。""" """转发消息节点列表。"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if len(args) == 1: if len(args) == 1:
@ -39,7 +38,7 @@ class Forward(MessageComponent):
@strategy_model.strategy_class("forward") @strategy_model.strategy_class("forward")
class ForwardComponentStrategy(strategy_model.LongTextStrategy): class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
display = ForwardMessageDiaplay( display = ForwardMessageDiaplay(
title="群聊的聊天记录", title="群聊的聊天记录",
brief="[聊天记录]", brief="[聊天记录]",
@ -49,10 +48,10 @@ class ForwardComponentStrategy(strategy_model.LongTextStrategy):
) )
node_list = [ node_list = [
ForwardMessageNode( platform_message.ForwardMessageNode(
sender_id=query.adapter.bot_account_id, sender_id=query.adapter.bot_account_id,
sender_name='QQ用户', sender_name='QQ用户',
message_chain=MessageChain([message]) message_chain=platform_message.MessageChain([message])
) )
] ]

View File

@ -8,8 +8,7 @@ import re
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from mirai.models import MessageChain, Image as ImageComponent from ....platform.types import message as platform_message
from mirai.models.message import MessageComponent
from .. import strategy as strategy_model from .. import strategy as strategy_model
from ....core import entities as core_entities from ....core import entities as core_entities
@ -23,7 +22,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
async def initialize(self): async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8") self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image( img_path = self.text_to_image(
text_str=message, text_str=message,
save_as='temp/{}.png'.format(int(time.time())) save_as='temp/{}.png'.format(int(time.time()))
@ -46,7 +45,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
os.remove(compressed_path) os.remove(compressed_path)
return [ return [
ImageComponent( platform_message.Image(
base64=b64.decode('utf-8'), base64=b64.decode('utf-8'),
) )
] ]

View File

@ -2,11 +2,10 @@ from __future__ import annotations
import abc import abc
import typing import typing
import mirai
from mirai.models.message import MessageComponent
from ...core import app from ...core import app
from ...core import entities as core_entities from ...core import entities as core_entities
from ...platform.types import message as platform_message
preregistered_strategies: list[typing.Type[LongTextStrategy]] = [] preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
@ -51,7 +50,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
"""处理长文本 """处理长文本
platform.json 中配置 long-text-process 字段只要 文本长度超过了 threshold 就会调用此方法 platform.json 中配置 long-text-process 字段只要 文本长度超过了 threshold 就会调用此方法
@ -61,6 +60,6 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
query (core_entities.Query): 此次请求的上下文对象 query (core_entities.Query): 此次请求的上下文对象
Returns: Returns:
list[mirai.models.messages.MessageComponent]: 转换后的 YiriMirai 消息组件列表 list[platform_message.MessageComponent]: 转换后的 平台 消息组件列表
""" """
return [] return []

View File

@ -2,10 +2,11 @@ from __future__ import annotations
import asyncio import asyncio
import mirai
from ..core import entities from ..core import entities
from ..platform import adapter as msadapter from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
from ..platform.types import events as platform_events
class QueryPool: class QueryPool:
@ -30,8 +31,8 @@ class QueryPool:
launcher_type: entities.LauncherTypes, launcher_type: entities.LauncherTypes,
launcher_id: int, launcher_id: int,
sender_id: int, sender_id: int,
message_event: mirai.MessageEvent, message_event: platform_events.MessageEvent,
message_chain: mirai.MessageChain, message_chain: platform_message.MessageChain,
adapter: msadapter.MessageSourceAdapter adapter: msadapter.MessageSourceAdapter
) -> entities.Query: ) -> entities.Query:
async with self.condition: async with self.condition:

View File

@ -1,11 +1,11 @@
from __future__ import annotations from __future__ import annotations
import mirai
from .. import stage, entities, stagemgr from .. import stage, entities, stagemgr
from ...core import entities as core_entities from ...core import entities as core_entities
from ...provider import entities as llm_entities from ...provider import entities as llm_entities
from ...plugin import events from ...plugin import events
from ...platform.types import message as platform_message
@stage.stage_class("PreProcessor") @stage.stage_class("PreProcessor")
@ -55,11 +55,11 @@ class PreProcessor(stage.PipelineStage):
content_list = [] content_list = []
for me in query.message_chain: for me in query.message_chain:
if isinstance(me, mirai.Plain): if isinstance(me, platform_message.Plain):
content_list.append( content_list.append(
llm_entities.ContentElement.from_text(me.text) llm_entities.ContentElement.from_text(me.text)
) )
elif isinstance(me, mirai.Image): elif isinstance(me, platform_message.Image):
if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported: if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported:
if me.url is not None: if me.url is not None:
content_list.append( content_list.append(

View File

@ -5,7 +5,6 @@ import time
import traceback import traceback
import json import json
import mirai
from .. import handler from .. import handler
from ... import entities from ... import entities
@ -13,6 +12,8 @@ from ....core import entities as core_entities
from ....provider import entities as llm_entities, runnermgr from ....provider import entities as llm_entities, runnermgr
from ....plugin import events from ....plugin import events
from ....platform.types import message as platform_message
class ChatMessageHandler(handler.MessageHandler): class ChatMessageHandler(handler.MessageHandler):
@ -40,7 +41,7 @@ class ChatMessageHandler(handler.MessageHandler):
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply) mc = platform_message.MessageChain(event_ctx.event.reply)
query.resp_messages.append(mc) query.resp_messages.append(mc)

View File

@ -1,13 +1,13 @@
from __future__ import annotations from __future__ import annotations
import typing import typing
import mirai
from .. import handler from .. import handler
from ... import entities from ... import entities
from ....core import entities as core_entities from ....core import entities as core_entities
from ....provider import entities as llm_entities from ....provider import entities as llm_entities
from ....plugin import events from ....plugin import events
from ....platform.types import message as platform_message
class CommandHandler(handler.MessageHandler): class CommandHandler(handler.MessageHandler):
@ -46,7 +46,7 @@ class CommandHandler(handler.MessageHandler):
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply) mc = platform_message.MessageChain(event_ctx.event.reply)
query.resp_messages.append(mc) query.resp_messages.append(mc)
@ -63,8 +63,8 @@ class CommandHandler(handler.MessageHandler):
else: else:
if event_ctx.event.alter is not None: if event_ctx.event.alter is not None:
query.message_chain = mirai.MessageChain([ query.message_chain = platform_message.MessageChain([
mirai.Plain(event_ctx.event.alter) platform_message.Plain(event_ctx.event.alter)
]) ])
session = await self.ap.sess_mgr.get_session(query) session = await self.ap.sess_mgr.get_session(query)

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import random import random
import asyncio import asyncio
import mirai
from ...core import app from ...core import app

View File

@ -1,9 +1,10 @@
import pydantic import pydantic
import mirai
from ...platform.types import message as platform_message
class RuleJudgeResult(pydantic.BaseModel): class RuleJudgeResult(pydantic.BaseModel):
matching: bool = False matching: bool = False
replacement: mirai.MessageChain = None replacement: platform_message.MessageChain = None

View File

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import mirai
from ...core import app from ...core import app
from . import entities as rule_entities, rule from . import entities as rule_entities, rule

View File

@ -2,11 +2,11 @@ from __future__ import annotations
import abc import abc
import typing import typing
import mirai
from ...core import app, entities as core_entities from ...core import app, entities as core_entities
from . import entities from . import entities
from ...platform.types import message as platform_message
preregisetered_rules: list[typing.Type[GroupRespondRule]] = [] preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
@ -35,7 +35,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
async def match( async def match(
self, self,
message_text: str, message_text: str,
message_chain: mirai.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query query: core_entities.Query
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:

View File

@ -1,10 +1,10 @@
from __future__ import annotations from __future__ import annotations
import mirai
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("at-bot") @rule_model.rule_class("at-bot")
@ -13,16 +13,16 @@ class AtBotRule(rule_model.GroupRespondRule):
async def match( async def match(
self, self,
message_text: str, message_text: str,
message_chain: mirai.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query query: core_entities.Query
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']: if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(query.adapter.bot_account_id)) message_chain.remove(platform_message.At(query.adapter.bot_account_id))
if message_chain.has(mirai.At(query.adapter.bot_account_id)): # 回复消息时会at两次检查并删除重复的 if message_chain.has(platform_message.At(query.adapter.bot_account_id)): # 回复消息时会at两次检查并删除重复的
message_chain.remove(mirai.At(query.adapter.bot_account_id)) message_chain.remove(platform_message.At(query.adapter.bot_account_id))
return entities.RuleJudgeResult( return entities.RuleJudgeResult(
matching=True, matching=True,

View File

@ -1,8 +1,8 @@
import mirai
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("prefix") @rule_model.rule_class("prefix")
@ -11,7 +11,7 @@ class PrefixRule(rule_model.GroupRespondRule):
async def match( async def match(
self, self,
message_text: str, message_text: str,
message_chain: mirai.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query query: core_entities.Query
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
@ -22,7 +22,7 @@ class PrefixRule(rule_model.GroupRespondRule):
# 查找第一个plain元素 # 查找第一个plain元素
for me in message_chain: for me in message_chain:
if isinstance(me, mirai.Plain): if isinstance(me, platform_message.Plain):
me.text = me.text[len(prefix):] me.text = me.text[len(prefix):]
return entities.RuleJudgeResult( return entities.RuleJudgeResult(

View File

@ -1,10 +1,10 @@
import random import random
import mirai
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("random") @rule_model.rule_class("random")
@ -13,7 +13,7 @@ class RandomRespRule(rule_model.GroupRespondRule):
async def match( async def match(
self, self,
message_text: str, message_text: str,
message_chain: mirai.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query query: core_entities.Query
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:

View File

@ -1,10 +1,10 @@
import re import re
import mirai
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("regexp") @rule_model.rule_class("regexp")
@ -13,7 +13,7 @@ class RegExpRule(rule_model.GroupRespondRule):
async def match( async def match(
self, self,
message_text: str, message_text: str,
message_chain: mirai.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query query: core_entities.Query
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import typing import typing
import mirai
from ...core import app, entities as core_entities from ...core import app, entities as core_entities
from .. import entities from .. import entities
@ -10,6 +9,7 @@ from .. import stage, entities, stagemgr
from ...core import entities as core_entities from ...core import entities as core_entities
from ...config import manager as cfg_mgr from ...config import manager as cfg_mgr
from ...plugin import events from ...plugin import events
from ...platform.types import message as platform_message
@stage.stage_class("ResponseWrapper") @stage.stage_class("ResponseWrapper")
@ -34,7 +34,7 @@ class ResponseWrapper(stage.PipelineStage):
""" """
# 如果 resp_messages[-1] 已经是 MessageChain 了 # 如果 resp_messages[-1] 已经是 MessageChain 了
if isinstance(query.resp_messages[-1], mirai.MessageChain): if isinstance(query.resp_messages[-1], platform_message.MessageChain):
query.resp_message_chain.append(query.resp_messages[-1]) query.resp_message_chain.append(query.resp_messages[-1])
yield entities.StageProcessResult( yield entities.StageProcessResult(
@ -45,19 +45,14 @@ class ResponseWrapper(stage.PipelineStage):
else: else:
if query.resp_messages[-1].role == 'command': if query.resp_messages[-1].role == 'command':
# query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content)) query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain(prefix_text='[bot] '))
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
) )
elif query.resp_messages[-1].role == 'plugin': elif query.resp_messages[-1].role == 'plugin':
# if not isinstance(query.resp_messages[-1].content, mirai.MessageChain): query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain())
# query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
# else:
# query.resp_message_chain.append(query.resp_messages[-1].content)
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
@ -72,7 +67,7 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = '' reply_text = ''
if result.content: # 有内容 if result.content: # 有内容
reply_text = str(result.get_content_mirai_message_chain()) reply_text = str(result.get_content_platform_message_chain())
# ============= 触发插件事件 =============== # ============= 触发插件事件 ===============
event_ctx = await self.ap.plugin_mgr.emit_event( event_ctx = await self.ap.plugin_mgr.emit_event(
@ -96,11 +91,11 @@ class ResponseWrapper(stage.PipelineStage):
else: else:
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply)) query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
else: else:
query.resp_message_chain.append(result.get_content_mirai_message_chain()) query.resp_message_chain.append(result.get_content_platform_message_chain())
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
@ -113,7 +108,7 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = f'调用函数 {".".join(function_names)}...' reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)])) query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
if self.ap.platform_cfg.data['track-function-calls']: if self.ap.platform_cfg.data['track-function-calls']:
@ -139,11 +134,11 @@ class ResponseWrapper(stage.PipelineStage):
else: else:
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply)) query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
else: else:
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)])) query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,

View File

@ -4,9 +4,10 @@ from __future__ import annotations
import typing import typing
import abc import abc
import mirai
from ..core import app from ..core import app
from .types import message as platform_message
from .types import events as platform_events
preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = [] preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = []
@ -55,28 +56,28 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
self, self,
target_type: str, target_type: str,
target_id: str, target_id: str,
message: mirai.MessageChain message: platform_message.MessageChain
): ):
"""主动发送消息 """主动发送消息
Args: Args:
target_type (str): 目标类型`person``group` target_type (str): 目标类型`person``group`
target_id (str): 目标ID target_id (str): 目标ID
message (mirai.MessageChain): YiriMirai库的消息链 message (platform.types.MessageChain): 消息链
""" """
raise NotImplementedError raise NotImplementedError
async def reply_message( async def reply_message(
self, self,
message_source: mirai.MessageEvent, message_source: platform_events.MessageEvent,
message: mirai.MessageChain, message: platform_message.MessageChain,
quote_origin: bool = False quote_origin: bool = False
): ):
"""回复消息 """回复消息
Args: Args:
message_source (mirai.MessageEvent): YiriMirai消息源事件 message_source (platform.types.MessageEvent): 消息源事件
message (mirai.MessageChain): YiriMirai库的消息链 message (platform.types.MessageChain): 消息链
quote_origin (bool, optional): 是否引用原消息. Defaults to False. quote_origin (bool, optional): 是否引用原消息. Defaults to False.
""" """
raise NotImplementedError raise NotImplementedError
@ -87,27 +88,27 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_message.Event],
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None] callback: typing.Callable[[platform_message.Event, MessageSourceAdapter], None]
): ):
"""注册事件监听器 """注册事件监听器
Args: Args:
event_type (typing.Type[mirai.Event]): YiriMirai事件类型 event_type (typing.Type[platform.types.Event]): 事件类型
callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数YiriMirai事件 callback (typing.Callable[[platform.types.Event], None]): 回调函数接收一个参数事件
""" """
raise NotImplementedError raise NotImplementedError
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_message.Event],
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None] callback: typing.Callable[[platform_message.Event, MessageSourceAdapter], None]
): ):
"""注销事件监听器 """注销事件监听器
Args: Args:
event_type (typing.Type[mirai.Event]): YiriMirai事件类型 event_type (typing.Type[platform.types.Event]): 事件类型
callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数YiriMirai事件 callback (typing.Callable[[platform.types.Event], None]): 回调函数接收一个参数事件
""" """
raise NotImplementedError raise NotImplementedError
@ -127,26 +128,26 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
class MessageConverter: class MessageConverter:
"""消息链转换器基类""" """消息链转换器基类"""
@staticmethod @staticmethod
def yiri2target(message_chain: mirai.MessageChain): def yiri2target(message_chain: platform_message.MessageChain):
"""YiriMirai消息链转换为目标消息链 """源平台消息链转换为目标平台消息链
Args: Args:
message_chain (mirai.MessageChain): YiriMirai消息链 message_chain (platform.types.MessageChain): 源平台消息链
Returns: Returns:
typing.Any: 目标消息链 typing.Any: 目标平台消息链
""" """
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
def target2yiri(message_chain: typing.Any) -> mirai.MessageChain: def target2yiri(message_chain: typing.Any) -> platform_message.MessageChain:
"""将目标消息链转换为YiriMirai消息链 """将目标平台消息链转换为源平台消息链
Args: Args:
message_chain (typing.Any): 目标消息链 message_chain (typing.Any): 目标平台消息链
Returns: Returns:
mirai.MessageChain: YiriMirai消息链 platform.types.MessageChain: 源平台消息链
""" """
raise NotImplementedError raise NotImplementedError
@ -155,25 +156,25 @@ class EventConverter:
"""事件转换器基类""" """事件转换器基类"""
@staticmethod @staticmethod
def yiri2target(event: typing.Type[mirai.Event]): def yiri2target(event: typing.Type[platform_message.Event]):
"""YiriMirai事件转换为目标事件 """源平台事件转换为目标平台事件
Args: Args:
event (typing.Type[mirai.Event]): YiriMirai事件 event (typing.Type[platform.types.Event]): 源平台事件
Returns: Returns:
typing.Any: 目标事件 typing.Any: 目标平台事件
""" """
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
def target2yiri(event: typing.Any) -> mirai.Event: def target2yiri(event: typing.Any) -> platform_message.Event:
"""将目标事件的调用参数转换为YiriMirai的事件参数对象 """将目标平台事件的调用参数转换为源平台的事件参数对象
Args: Args:
event (typing.Any): 目标事件 event (typing.Any): 目标平台事件
Returns: Returns:
typing.Type[mirai.Event]: YiriMirai事件 typing.Type[platform.types.Event]: 源平台事件
""" """
raise NotImplementedError raise NotImplementedError

View File

@ -2,17 +2,24 @@ from __future__ import annotations
import json import json
import os import os
import sys
import logging import logging
import asyncio import asyncio
import traceback import traceback
from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \ # FriendMessage, Image, MessageChain, Plain
FriendMessage, Image, MessageChain, Plain
import mirai
from ..platform import adapter as msadapter from ..platform import adapter as msadapter
from ..core import app, entities as core_entities from ..core import app, entities as core_entities
from ..plugin import events from ..plugin import events
from .types import message as platform_message
from .types import events as platform_events
from .types import entities as platform_entities
# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题
from . import types as mirai
sys.modules['mirai'] = mirai
# 控制QQ消息输入输出的类 # 控制QQ消息输入输出的类
class PlatformManager: class PlatformManager:
@ -32,7 +39,7 @@ class PlatformManager:
from .sources import yirimirai, nakuru, aiocqhttp, qqbotpy from .sources import yirimirai, nakuru, aiocqhttp, qqbotpy
async def on_friend_message(event: FriendMessage, adapter: msadapter.MessageSourceAdapter): async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event( event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PersonMessageReceived( event=events.PersonMessageReceived(
@ -55,7 +62,7 @@ class PlatformManager:
adapter=adapter adapter=adapter
) )
async def on_stranger_message(event: StrangerMessage, adapter: msadapter.MessageSourceAdapter): async def on_stranger_message(event: platform_events.StrangerMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event( event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PersonMessageReceived( event=events.PersonMessageReceived(
@ -78,7 +85,7 @@ class PlatformManager:
adapter=adapter adapter=adapter
) )
async def on_group_message(event: GroupMessage, adapter: msadapter.MessageSourceAdapter): async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event( event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.GroupMessageReceived( event=events.GroupMessageReceived(
@ -127,16 +134,16 @@ class PlatformManager:
if adapter_name == 'yiri-mirai': if adapter_name == 'yiri-mirai':
adapter_inst.register_listener( adapter_inst.register_listener(
StrangerMessage, platform_events.StrangerMessage,
on_stranger_message on_stranger_message
) )
adapter_inst.register_listener( adapter_inst.register_listener(
FriendMessage, platform_events.FriendMessage,
on_friend_message on_friend_message
) )
adapter_inst.register_listener( adapter_inst.register_listener(
GroupMessage, platform_events.GroupMessage,
on_group_message on_group_message
) )
@ -146,13 +153,13 @@ class PlatformManager:
if len(self.adapters) == 0: if len(self.adapters) == 0:
self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。') self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。')
async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter): async def send(self, event: platform_events.MessageEvent, msg: platform_message.MessageChain, adapter: msadapter.MessageSourceAdapter):
if self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage): if self.ap.platform_cfg.data['at-sender'] and isinstance(event, platform_events.GroupMessage):
msg.insert( msg.insert(
0, 0,
At( platform_message.At(
event.sender.id event.sender.id
) )
) )

View File

@ -5,31 +5,32 @@ import traceback
import time import time
import datetime import datetime
import mirai
import mirai.models.message as yiri_message
import aiocqhttp import aiocqhttp
from .. import adapter from .. import adapter
from ...pipeline.longtext.strategies import forward from ...pipeline.longtext.strategies import forward
from ...core import app from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
class AiocqhttpMessageConverter(adapter.MessageConverter): class AiocqhttpMessageConverter(adapter.MessageConverter):
@staticmethod @staticmethod
def yiri2target(message_chain: mirai.MessageChain) -> typing.Tuple[list, int, datetime.datetime]: def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]:
msg_list = aiocqhttp.Message() msg_list = aiocqhttp.Message()
msg_id = 0 msg_id = 0
msg_time = None msg_time = None
for msg in message_chain: for msg in message_chain:
if type(msg) is mirai.Plain: if type(msg) is platform_message.Plain:
msg_list.append(aiocqhttp.MessageSegment.text(msg.text)) msg_list.append(aiocqhttp.MessageSegment.text(msg.text))
elif type(msg) is yiri_message.Source: elif type(msg) is platform_message.Source:
msg_id = msg.id msg_id = msg.id
msg_time = msg.time msg_time = msg.time
elif type(msg) is mirai.Image: elif type(msg) is platform_message.Image:
arg = '' arg = ''
if msg.base64: if msg.base64:
arg = msg.base64 arg = msg.base64
@ -40,13 +41,11 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif msg.path: elif msg.path:
arg = msg.path arg = msg.path
msg_list.append(aiocqhttp.MessageSegment.image(arg)) msg_list.append(aiocqhttp.MessageSegment.image(arg))
elif type(msg) is mirai.At: elif type(msg) is platform_message.At:
msg_list.append(aiocqhttp.MessageSegment.at(msg.target)) msg_list.append(aiocqhttp.MessageSegment.at(msg.target))
elif type(msg) is mirai.AtAll: elif type(msg) is platform_message.AtAll:
msg_list.append(aiocqhttp.MessageSegment.at("all")) msg_list.append(aiocqhttp.MessageSegment.at("all"))
elif type(msg) is mirai.Face: elif type(msg) is platform_message.Voice:
msg_list.append(aiocqhttp.MessageSegment.face(msg.face_id))
elif type(msg) is mirai.Voice:
arg = '' arg = ''
if msg.base64: if msg.base64:
arg = msg.base64 arg = msg.base64
@ -74,25 +73,25 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
yiri_msg_list = [] yiri_msg_list = []
yiri_msg_list.append( yiri_msg_list.append(
yiri_message.Source(id=message_id, time=datetime.datetime.now()) platform_message.Source(id=message_id, time=datetime.datetime.now())
) )
for msg in message: for msg in message:
if msg.type == "at": if msg.type == "at":
if msg.data["qq"] == "all": if msg.data["qq"] == "all":
yiri_msg_list.append(yiri_message.AtAll()) yiri_msg_list.append(platform_message.AtAll())
else: else:
yiri_msg_list.append( yiri_msg_list.append(
yiri_message.At( platform_message.At(
target=msg.data["qq"], target=msg.data["qq"],
) )
) )
elif msg.type == "text": elif msg.type == "text":
yiri_msg_list.append(yiri_message.Plain(text=msg.data["text"])) yiri_msg_list.append(platform_message.Plain(text=msg.data["text"]))
elif msg.type == "image": elif msg.type == "image":
yiri_msg_list.append(yiri_message.Image(url=msg.data["url"])) yiri_msg_list.append(platform_message.Image(url=msg.data["url"]))
chain = mirai.MessageChain(yiri_msg_list) chain = platform_message.MessageChain(yiri_msg_list)
return chain return chain
@ -100,11 +99,11 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
class AiocqhttpEventConverter(adapter.EventConverter): class AiocqhttpEventConverter(adapter.EventConverter):
@staticmethod @staticmethod
def yiri2target(event: mirai.Event, bot_account_id: int): def yiri2target(event: platform_events.Event, bot_account_id: int):
msg, msg_id, msg_time = AiocqhttpMessageConverter.yiri2target(event.message_chain) msg, msg_id, msg_time = AiocqhttpMessageConverter.yiri2target(event.message_chain)
if type(event) is mirai.GroupMessage: if type(event) is platform_events.GroupMessage:
role = "member" role = "member"
if event.sender.permission == "ADMINISTRATOR": if event.sender.permission == "ADMINISTRATOR":
@ -140,7 +139,7 @@ class AiocqhttpEventConverter(adapter.EventConverter):
} }
return aiocqhttp.Event.from_payload(payload) return aiocqhttp.Event.from_payload(payload)
elif type(event) is mirai.FriendMessage: elif type(event) is platform_events.FriendMessage:
payload = { payload = {
"post_type": "message", "post_type": "message",
@ -178,15 +177,15 @@ class AiocqhttpEventConverter(adapter.EventConverter):
permission = "ADMINISTRATOR" permission = "ADMINISTRATOR"
elif event.sender["role"] == "owner": elif event.sender["role"] == "owner":
permission = "OWNER" permission = "OWNER"
converted_event = mirai.GroupMessage( converted_event = platform_events.GroupMessage(
sender=mirai.models.entities.GroupMember( sender=platform_entities.GroupMember(
id=event.sender["user_id"], # message_seq 放哪? id=event.sender["user_id"], # message_seq 放哪?
member_name=event.sender["nickname"], member_name=event.sender["nickname"],
permission=permission, permission=permission,
group=mirai.models.entities.Group( group=platform_entities.Group(
id=event.group_id, id=event.group_id,
name=event.sender["nickname"], name=event.sender["nickname"],
permission=mirai.models.entities.Permission.Member, permission=platform_entities.Permission.Member,
), ),
special_title=event.sender["title"] if "title" in event.sender else "", special_title=event.sender["title"] if "title" in event.sender else "",
join_timestamp=0, join_timestamp=0,
@ -198,8 +197,8 @@ class AiocqhttpEventConverter(adapter.EventConverter):
) )
return converted_event return converted_event
elif event.message_type == "private": elif event.message_type == "private":
return mirai.FriendMessage( return platform_events.FriendMessage(
sender=mirai.models.entities.Friend( sender=platform_entities.Friend(
id=event.sender["user_id"], id=event.sender["user_id"],
nickname=event.sender["nickname"], nickname=event.sender["nickname"],
remark="", remark="",
@ -241,7 +240,7 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
self.bot = aiocqhttp.CQHttp() self.bot = aiocqhttp.CQHttp()
async def send_message( async def send_message(
self, target_type: str, target_id: str, message: mirai.MessageChain self, target_type: str, target_id: str, message: platform_message.MessageChain
): ):
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0] aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
@ -252,8 +251,8 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
async def reply_message( async def reply_message(
self, self,
message_source: mirai.MessageEvent, message_source: platform_events.MessageEvent,
message: mirai.MessageChain, message: platform_message.MessageChain,
quote_origin: bool = False, quote_origin: bool = False,
): ):
aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id) aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
@ -271,8 +270,8 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None], callback: typing.Callable[[platform_events.Event, adapter.MessageSourceAdapter], None],
): ):
async def on_message(event: aiocqhttp.Event): async def on_message(event: aiocqhttp.Event):
self.bot_account_id = event.self_id self.bot_account_id = event.self_id
@ -281,15 +280,15 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
except: except:
traceback.print_exc() traceback.print_exc()
if event_type == mirai.GroupMessage: if event_type == platform_events.GroupMessage:
self.bot.on_message("group")(on_message) self.bot.on_message("group")(on_message)
elif event_type == mirai.FriendMessage: elif event_type == platform_events.FriendMessage:
self.bot.on_message("private")(on_message) self.bot.on_message("private")(on_message)
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None], callback: typing.Callable[[platform_events.Event, adapter.MessageSourceAdapter], None],
): ):
return super().unregister_listener(event_type, callback) return super().unregister_listener(event_type, callback)

View File

@ -6,26 +6,28 @@ import typing
import traceback import traceback
import logging import logging
import mirai
import nakuru import nakuru
import nakuru.entities.components as nkc import nakuru.entities.components as nkc
from .. import adapter as adapter_model from .. import adapter as adapter_model
from ...pipeline.longtext.strategies import forward from ...pipeline.longtext.strategies import forward
from ...platform.types import message as platform_message
from ...platform.types import entities as platform_entities
from ...platform.types import events as platform_events
class NakuruProjectMessageConverter(adapter_model.MessageConverter): class NakuruProjectMessageConverter(adapter_model.MessageConverter):
"""消息转换器""" """消息转换器"""
@staticmethod @staticmethod
def yiri2target(message_chain: mirai.MessageChain) -> list: def yiri2target(message_chain: platform_message.MessageChain) -> list:
msg_list = [] msg_list = []
if type(message_chain) is mirai.MessageChain: if type(message_chain) is platform_message.MessageChain:
msg_list = message_chain.__root__ msg_list = message_chain.__root__
elif type(message_chain) is list: elif type(message_chain) is list:
msg_list = message_chain msg_list = message_chain
elif type(message_chain) is str: elif type(message_chain) is str:
msg_list = [mirai.Plain(message_chain)] msg_list = [platform_message.Plain(message_chain)]
else: else:
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain))) raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
@ -33,22 +35,20 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
# 遍历并转换 # 遍历并转换
for component in msg_list: for component in msg_list:
if type(component) is mirai.Plain: if type(component) is platform_message.Plain:
nakuru_msg_list.append(nkc.Plain(component.text, False)) nakuru_msg_list.append(nkc.Plain(component.text, False))
elif type(component) is mirai.Image: elif type(component) is platform_message.Image:
if component.url is not None: if component.url is not None:
nakuru_msg_list.append(nkc.Image.fromURL(component.url)) nakuru_msg_list.append(nkc.Image.fromURL(component.url))
elif component.base64 is not None: elif component.base64 is not None:
nakuru_msg_list.append(nkc.Image.fromBase64(component.base64)) nakuru_msg_list.append(nkc.Image.fromBase64(component.base64))
elif component.path is not None: elif component.path is not None:
nakuru_msg_list.append(nkc.Image.fromFileSystem(component.path)) nakuru_msg_list.append(nkc.Image.fromFileSystem(component.path))
elif type(component) is mirai.Face: elif type(component) is platform_message.At:
nakuru_msg_list.append(nkc.Face(id=component.face_id))
elif type(component) is mirai.At:
nakuru_msg_list.append(nkc.At(qq=component.target)) nakuru_msg_list.append(nkc.At(qq=component.target))
elif type(component) is mirai.AtAll: elif type(component) is platform_message.AtAll:
nakuru_msg_list.append(nkc.AtAll()) nakuru_msg_list.append(nkc.AtAll())
elif type(component) is mirai.Voice: elif type(component) is platform_message.Voice:
if component.url is not None: if component.url is not None:
nakuru_msg_list.append(nkc.Record.fromURL(component.url)) nakuru_msg_list.append(nkc.Record.fromURL(component.url))
elif component.path is not None: elif component.path is not None:
@ -80,49 +80,47 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
return nakuru_msg_list return nakuru_msg_list
@staticmethod @staticmethod
def target2yiri(message_chain: typing.Any, message_id: int = -1) -> mirai.MessageChain: def target2yiri(message_chain: typing.Any, message_id: int = -1) -> platform_message.MessageChain:
"""将Yiri的消息链转换为YiriMirai的消息链""" """将Yiri的消息链转换为YiriMirai的消息链"""
assert type(message_chain) is list assert type(message_chain) is list
yiri_msg_list = [] yiri_msg_list = []
import datetime import datetime
# 添加Source组件以标记message_id等信息 # 添加Source组件以标记message_id等信息
yiri_msg_list.append(mirai.models.message.Source(id=message_id, time=datetime.datetime.now())) yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
for component in message_chain: for component in message_chain:
if type(component) is nkc.Plain: if type(component) is nkc.Plain:
yiri_msg_list.append(mirai.Plain(text=component.text)) yiri_msg_list.append(platform_message.Plain(text=component.text))
elif type(component) is nkc.Image: elif type(component) is nkc.Image:
yiri_msg_list.append(mirai.Image(url=component.url)) yiri_msg_list.append(platform_message.Image(url=component.url))
elif type(component) is nkc.Face:
yiri_msg_list.append(mirai.Face(face_id=component.id))
elif type(component) is nkc.At: elif type(component) is nkc.At:
yiri_msg_list.append(mirai.At(target=component.qq)) yiri_msg_list.append(platform_message.At(target=component.qq))
elif type(component) is nkc.AtAll: elif type(component) is nkc.AtAll:
yiri_msg_list.append(mirai.AtAll()) yiri_msg_list.append(platform_message.AtAll())
else: else:
pass pass
# logging.debug("转换后的消息链: " + str(yiri_msg_list)) # logging.debug("转换后的消息链: " + str(yiri_msg_list))
chain = mirai.MessageChain(yiri_msg_list) chain = platform_message.MessageChain(yiri_msg_list)
return chain return chain
class NakuruProjectEventConverter(adapter_model.EventConverter): class NakuruProjectEventConverter(adapter_model.EventConverter):
"""事件转换器""" """事件转换器"""
@staticmethod @staticmethod
def yiri2target(event: typing.Type[mirai.Event]): def yiri2target(event: typing.Type[platform_events.Event]):
if event is mirai.GroupMessage: if event is platform_events.GroupMessage:
return nakuru.GroupMessage return nakuru.GroupMessage
elif event is mirai.FriendMessage: elif event is platform_events.FriendMessage:
return nakuru.FriendMessage return nakuru.FriendMessage
else: else:
raise Exception("未支持转换的事件类型: " + str(event)) raise Exception("未支持转换的事件类型: " + str(event))
@staticmethod @staticmethod
def target2yiri(event: typing.Any) -> mirai.Event: def target2yiri(event: typing.Any) -> platform_events.Event:
yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id) yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id)
if type(event) is nakuru.FriendMessage: # 私聊消息事件 if type(event) is nakuru.FriendMessage: # 私聊消息事件
return mirai.FriendMessage( return platform_events.FriendMessage(
sender=mirai.models.entities.Friend( sender=platform_entities.Friend(
id=event.sender.user_id, id=event.sender.user_id,
nickname=event.sender.nickname, nickname=event.sender.nickname,
remark=event.sender.nickname remark=event.sender.nickname
@ -138,16 +136,15 @@ class NakuruProjectEventConverter(adapter_model.EventConverter):
elif event.sender.role == "owner": elif event.sender.role == "owner":
permission = "OWNER" permission = "OWNER"
import mirai.models.entities as entities return platform_events.GroupMessage(
return mirai.GroupMessage( sender=platform_entities.GroupMember(
sender=mirai.models.entities.GroupMember(
id=event.sender.user_id, id=event.sender.user_id,
member_name=event.sender.nickname, member_name=event.sender.nickname,
permission=permission, permission=permission,
group=mirai.models.entities.Group( group=platform_entities.Group(
id=event.group_id, id=event.group_id,
name=event.sender.nickname, name=event.sender.nickname,
permission=entities.Permission.Member permission=platform_entities.Permission.Member
), ),
special_title=event.sender.title, special_title=event.sender.title,
join_timestamp=0, join_timestamp=0,
@ -189,7 +186,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
self, self,
target_type: str, target_type: str,
target_id: str, target_id: str,
message: typing.Union[mirai.MessageChain, list], message: typing.Union[platform_message.MessageChain, list],
converted: bool = False converted: bool = False
): ):
task = None task = None
@ -222,8 +219,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
async def reply_message( async def reply_message(
self, self,
message_source: mirai.MessageEvent, message_source: platform_events.MessageEvent,
message: mirai.MessageChain, message: platform_message.MessageChain,
quote_origin: bool = False quote_origin: bool = False
): ):
message = self.message_converter.yiri2target(message) message = self.message_converter.yiri2target(message)
@ -233,14 +230,14 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
id=message_source.message_chain.message_id, id=message_source.message_chain.message_id,
) )
) )
if type(message_source) is mirai.GroupMessage: if type(message_source) is platform_events.GroupMessage:
await self.send_message( await self.send_message(
"group", "group",
message_source.sender.group.id, message_source.sender.group.id,
message, message,
converted=True converted=True
) )
elif type(message_source) is mirai.FriendMessage: elif type(message_source) is platform_events.FriendMessage:
await self.send_message( await self.send_message(
"person", "person",
message_source.sender.id, message_source.sender.id,
@ -258,8 +255,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] callback: typing.Callable[[platform_events.Event, adapter_model.MessageSourceAdapter], None]
): ):
try: try:
@ -286,8 +283,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] callback: typing.Callable[[platform_events.Event, adapter_model.MessageSourceAdapter], None]
): ):
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__ nakuru_event_name = self.event_converter.yiri2target(event_type).__name__

View File

@ -6,7 +6,6 @@ import datetime
import re import re
import traceback import traceback
import mirai
import botpy import botpy
import botpy.message as botpy_message import botpy.message as botpy_message
import botpy.types.message as botpy_message_type import botpy.types.message as botpy_message_type
@ -17,17 +16,21 @@ from .. import adapter as adapter_model
from ...pipeline.longtext.strategies import forward from ...pipeline.longtext.strategies import forward
from ...core import app from ...core import app
from ...config import manager as cfg_mgr from ...config import manager as cfg_mgr
from ...platform.types import entities as platform_entities
from ...platform.types import events as platform_events
from ...platform.types import message as platform_message
class OfficialGroupMessage(mirai.GroupMessage):
class OfficialGroupMessage(platform_events.GroupMessage):
pass pass
class OfficialFriendMessage(mirai.FriendMessage): class OfficialFriendMessage(platform_events.FriendMessage):
pass pass
event_handler_mapping = { event_handler_mapping = {
mirai.GroupMessage: ["on_at_message_create", "on_group_at_message_create"], platform_events.GroupMessage: ["on_at_message_create", "on_group_at_message_create"],
mirai.FriendMessage: ["on_direct_message_create", "on_c2c_message_create"], platform_events.FriendMessage: ["on_direct_message_create", "on_c2c_message_create"],
} }
@ -123,16 +126,16 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
"""QQ 官方消息转换器""" """QQ 官方消息转换器"""
@staticmethod @staticmethod
def yiri2target(message_chain: mirai.MessageChain): def yiri2target(message_chain: platform_message.MessageChain):
"""将 YiriMirai 的消息链转换为 QQ 官方消息""" """将 YiriMirai 的消息链转换为 QQ 官方消息"""
msg_list = [] msg_list = []
if type(message_chain) is mirai.MessageChain: if type(message_chain) is platform_message.MessageChain:
msg_list = message_chain.__root__ msg_list = message_chain.__root__
elif type(message_chain) is list: elif type(message_chain) is list:
msg_list = message_chain msg_list = message_chain
elif type(message_chain) is str: elif type(message_chain) is str:
msg_list = [mirai.Plain(text=message_chain)] msg_list = [platform_message.Plain(text=message_chain)]
else: else:
raise Exception( raise Exception(
"Unknown message type: " + str(message_chain) + str(type(message_chain)) "Unknown message type: " + str(message_chain) + str(type(message_chain))
@ -153,22 +156,22 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
# 遍历并转换 # 遍历并转换
for component in msg_list: for component in msg_list:
if type(component) is mirai.Plain: if type(component) is platform_message.Plain:
offcial_messages.append({"type": "text", "content": component.text}) offcial_messages.append({"type": "text", "content": component.text})
elif type(component) is mirai.Image: elif type(component) is platform_message.Image:
if component.url is not None: if component.url is not None:
offcial_messages.append({"type": "image", "content": component.url}) offcial_messages.append({"type": "image", "content": component.url})
elif component.path is not None: elif component.path is not None:
offcial_messages.append( offcial_messages.append(
{"type": "file_image", "content": component.path} {"type": "file_image", "content": component.path}
) )
elif type(component) is mirai.At: elif type(component) is platform_message.At:
offcial_messages.append({"type": "at", "content": ""}) offcial_messages.append({"type": "at", "content": ""})
elif type(component) is mirai.AtAll: elif type(component) is platform_message.AtAll:
print( print(
"上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。" "上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
) )
elif type(component) is mirai.Voice: elif type(component) is platform_message.Voice:
print( print(
"上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。" "上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
) )
@ -197,29 +200,29 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage], message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage],
message_id: str = None, message_id: str = None,
bot_account_id: int = 0, bot_account_id: int = 0,
) -> mirai.MessageChain: ) -> platform_message.MessageChain:
yiri_msg_list = [] yiri_msg_list = []
# 存id # 存id
yiri_msg_list.append( yiri_msg_list.append(
mirai.models.message.Source( platform_message.Source(
id=save_msg_id(message_id), time=datetime.datetime.now() id=save_msg_id(message_id), time=datetime.datetime.now()
) )
) )
if type(message) not in [botpy_message.DirectMessage, botpy_message.C2CMessage]: if type(message) not in [botpy_message.DirectMessage, botpy_message.C2CMessage]:
yiri_msg_list.append(mirai.At(target=bot_account_id)) yiri_msg_list.append(platform_message.At(target=bot_account_id))
if hasattr(message, "mentions"): if hasattr(message, "mentions"):
for mention in message.mentions: for mention in message.mentions:
if mention.bot: if mention.bot:
continue continue
yiri_msg_list.append(mirai.At(target=mention.id)) yiri_msg_list.append(platform_message.At(target=mention.id))
for attachment in message.attachments: for attachment in message.attachments:
if attachment.content_type.startswith("image"): if attachment.content_type.startswith("image"):
yiri_msg_list.append(mirai.Image(url=attachment.url)) yiri_msg_list.append(platform_message.Image(url=attachment.url))
else: else:
logging.warning( logging.warning(
"不支持的附件类型:" + attachment.content_type + ",忽略此附件。" "不支持的附件类型:" + attachment.content_type + ",忽略此附件。"
@ -227,9 +230,9 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
content = re.sub(r"<@!\d+>", "", str(message.content)) content = re.sub(r"<@!\d+>", "", str(message.content))
if content.strip() != "": if content.strip() != "":
yiri_msg_list.append(mirai.Plain(text=content)) yiri_msg_list.append(platform_message.Plain(text=content))
chain = mirai.MessageChain(yiri_msg_list) chain = platform_message.MessageChain(yiri_msg_list)
return chain return chain
@ -244,10 +247,10 @@ class OfficialEventConverter(adapter_model.EventConverter):
self.member_openid_mapping = member_openid_mapping self.member_openid_mapping = member_openid_mapping
self.group_openid_mapping = group_openid_mapping self.group_openid_mapping = group_openid_mapping
def yiri2target(self, event: typing.Type[mirai.Event]): def yiri2target(self, event: typing.Type[platform_events.Event]):
if event == mirai.GroupMessage: if event == platform_events.GroupMessage:
return botpy_message.Message return botpy_message.Message
elif event == mirai.FriendMessage: elif event == platform_events.FriendMessage:
return botpy_message.DirectMessage return botpy_message.DirectMessage
else: else:
raise Exception( raise Exception(
@ -257,8 +260,7 @@ class OfficialEventConverter(adapter_model.EventConverter):
def target2yiri( def target2yiri(
self, self,
event: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage], event: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage],
) -> mirai.Event: ) -> platform_events.Event:
import mirai.models.entities as mirai_entities
if type(event) == botpy_message.Message: # 频道内,转群聊事件 if type(event) == botpy_message.Message: # 频道内,转群聊事件
permission = "MEMBER" permission = "MEMBER"
@ -268,15 +270,15 @@ class OfficialEventConverter(adapter_model.EventConverter):
elif "4" in event.member.roles: elif "4" in event.member.roles:
permission = "OWNER" permission = "OWNER"
return mirai.GroupMessage( return platform_events.GroupMessage(
sender=mirai_entities.GroupMember( sender=platform_entities.GroupMember(
id=event.author.id, id=event.author.id,
member_name=event.author.username, member_name=event.author.username,
permission=permission, permission=permission,
group=mirai_entities.Group( group=platform_entities.Group(
id=event.channel_id, id=event.channel_id,
name=event.author.username, name=event.author.username,
permission=mirai_entities.Permission.Member, permission=platform_entities.Permission.Member,
), ),
special_title="", special_title="",
join_timestamp=int( join_timestamp=int(
@ -297,8 +299,8 @@ class OfficialEventConverter(adapter_model.EventConverter):
), ),
) )
elif type(event) == botpy_message.DirectMessage: # 频道私聊,转私聊事件 elif type(event) == botpy_message.DirectMessage: # 频道私聊,转私聊事件
return mirai.FriendMessage( return platform_events.FriendMessage(
sender=mirai_entities.Friend( sender=platform_entities.Friend(
id=event.guild_id, id=event.guild_id,
nickname=event.author.username, nickname=event.author.username,
remark=event.author.username, remark=event.author.username,
@ -317,14 +319,14 @@ class OfficialEventConverter(adapter_model.EventConverter):
replacing_member_id = self.member_openid_mapping.save_openid(event.author.member_openid) replacing_member_id = self.member_openid_mapping.save_openid(event.author.member_openid)
return OfficialGroupMessage( return OfficialGroupMessage(
sender=mirai_entities.GroupMember( sender=platform_entities.GroupMember(
id=replacing_member_id, id=replacing_member_id,
member_name=replacing_member_id, member_name=replacing_member_id,
permission="MEMBER", permission="MEMBER",
group=mirai_entities.Group( group=platform_entities.Group(
id=self.group_openid_mapping.save_openid(event.group_openid), id=self.group_openid_mapping.save_openid(event.group_openid),
name=replacing_member_id, name=replacing_member_id,
permission=mirai_entities.Permission.Member, permission=platform_entities.Permission.Member,
), ),
special_title="", special_title="",
join_timestamp=int(0), join_timestamp=int(0),
@ -345,7 +347,7 @@ class OfficialEventConverter(adapter_model.EventConverter):
user_id_alter = self.member_openid_mapping.save_openid(event.author.user_openid) # 实测这里的user_openid与group的member_openid是一样的 user_id_alter = self.member_openid_mapping.save_openid(event.author.user_openid) # 实测这里的user_openid与group的member_openid是一样的
return OfficialFriendMessage( return OfficialFriendMessage(
sender=mirai_entities.Friend( sender=platform_entities.Friend(
id=user_id_alter, id=user_id_alter,
nickname=user_id_alter, nickname=user_id_alter,
remark=user_id_alter, remark=user_id_alter,
@ -410,7 +412,7 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
self.bot = botpy.Client(intents=intents) self.bot = botpy.Client(intents=intents)
async def send_message( async def send_message(
self, target_type: str, target_id: str, message: mirai.MessageChain self, target_type: str, target_id: str, message: platform_message.MessageChain
): ):
message_list = self.message_converter.yiri2target(message) message_list = self.message_converter.yiri2target(message)
@ -437,8 +439,8 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
async def reply_message( async def reply_message(
self, self,
message_source: mirai.MessageEvent, message_source: platform_events.MessageEvent,
message: mirai.MessageChain, message: platform_message.MessageChain,
quote_origin: bool = False, quote_origin: bool = False,
): ):
@ -463,13 +465,13 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
] ]
) )
if type(message_source) == mirai.GroupMessage: if type(message_source) == platform_events.GroupMessage:
args["channel_id"] = str(message_source.sender.group.id) args["channel_id"] = str(message_source.sender.group.id)
args["msg_id"] = cached_message_ids[ args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id) str(message_source.message_chain.message_id)
] ]
await self.bot.api.post_message(**args) await self.bot.api.post_message(**args)
elif type(message_source) == mirai.FriendMessage: elif type(message_source) == platform_events.FriendMessage:
args["guild_id"] = str(message_source.sender.id) args["guild_id"] = str(message_source.sender.id)
args["msg_id"] = cached_message_ids[ args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id) str(message_source.message_chain.message_id)
@ -534,9 +536,9 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[ callback: typing.Callable[
[mirai.Event, adapter_model.MessageSourceAdapter], None [platform_events.Event, adapter_model.MessageSourceAdapter], None
], ],
): ):
@ -560,9 +562,9 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[ callback: typing.Callable[
[mirai.Event, adapter_model.MessageSourceAdapter], None [platform_events.Event, adapter_model.MessageSourceAdapter], None
], ],
): ):
delattr(self.bot, event_handler_mapping[event_type]) delattr(self.bot, event_handler_mapping[event_type])

View File

@ -1,124 +1,121 @@
import asyncio # import asyncio
import typing # import typing
import mirai
import mirai.models.bus
from mirai.bot import MiraiRunner
from .. import adapter as adapter_model
from ...core import app
@adapter_model.adapter_class("yiri-mirai") # from .. import adapter as adapter_model
class YiriMiraiAdapter(adapter_model.MessageSourceAdapter): # from ...core import app
"""YiriMirai适配器"""
bot: mirai.Mirai
def __init__(self, config: dict, ap: app.Application):
"""初始化YiriMirai的对象"""
self.ap = ap
self.config = config
if 'adapter' not in config or \
config['adapter'] == 'WebSocketAdapter':
self.bot = mirai.Mirai(
qq=config['qq'],
adapter=mirai.WebSocketAdapter(
host=config['host'],
port=config['port'],
verify_key=config['verifyKey']
)
)
elif config['adapter'] == 'HTTPAdapter':
self.bot = mirai.Mirai(
qq=config['qq'],
adapter=mirai.HTTPAdapter(
host=config['host'],
port=config['port'],
verify_key=config['verifyKey']
)
)
else:
raise Exception('Unknown adapter for YiriMirai: ' + config['adapter'])
async def send_message( # @adapter_model.adapter_class("yiri-mirai")
self, # class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
target_type: str, # """YiriMirai适配器"""
target_id: str, # bot: mirai.Mirai
message: mirai.MessageChain
):
"""发送消息
Args: # def __init__(self, config: dict, ap: app.Application):
target_type (str): 目标类型`person``group` # """初始化YiriMirai的对象"""
target_id (str): 目标ID # self.ap = ap
message (mirai.MessageChain): YiriMirai库的消息链 # self.config = config
""" # if 'adapter' not in config or \
task = None # config['adapter'] == 'WebSocketAdapter':
if target_type == 'person': # self.bot = mirai.Mirai(
task = self.bot.send_friend_message(int(target_id), message) # qq=config['qq'],
elif target_type == 'group': # adapter=mirai.WebSocketAdapter(
task = self.bot.send_group_message(int(target_id), message) # host=config['host'],
else: # port=config['port'],
raise Exception('Unknown target type: ' + target_type) # verify_key=config['verifyKey']
# )
# )
# elif config['adapter'] == 'HTTPAdapter':
# self.bot = mirai.Mirai(
# qq=config['qq'],
# adapter=mirai.HTTPAdapter(
# host=config['host'],
# port=config['port'],
# verify_key=config['verifyKey']
# )
# )
# else:
# raise Exception('Unknown adapter for YiriMirai: ' + config['adapter'])
await task # async def send_message(
# self,
# target_type: str,
# target_id: str,
# message: mirai.MessageChain
# ):
# """发送消息
async def reply_message( # Args:
self, # target_type (str): 目标类型,`person`或`group`
message_source: mirai.MessageEvent, # target_id (str): 目标ID
message: mirai.MessageChain, # message (mirai.MessageChain): YiriMirai库的消息链
quote_origin: bool = False # """
): # task = None
"""回复消息 # if target_type == 'person':
# task = self.bot.send_friend_message(int(target_id), message)
# elif target_type == 'group':
# task = self.bot.send_group_message(int(target_id), message)
# else:
# raise Exception('Unknown target type: ' + target_type)
Args: # await task
message_source (mirai.MessageEvent): YiriMirai消息源事件
message (mirai.MessageChain): YiriMirai库的消息链
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
"""
await self.bot.send(message_source, message, quote_origin)
async def is_muted(self, group_id: int) -> bool: # async def reply_message(
result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get() # self,
if result.mute_time_remaining > 0: # message_source: mirai.MessageEvent,
return True # message: mirai.MessageChain,
return False # quote_origin: bool = False
# ):
# """回复消息
def register_listener( # Args:
self, # message_source (mirai.MessageEvent): YiriMirai消息源事件
event_type: typing.Type[mirai.Event], # message (mirai.MessageChain): YiriMirai库的消息链
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] # quote_origin (bool, optional): 是否引用原消息. Defaults to False.
): # """
"""注册事件监听器 # await self.bot.send(message_source, message, quote_origin)
Args: # async def is_muted(self, group_id: int) -> bool:
event_type (typing.Type[mirai.Event]): YiriMirai事件类型 # result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get()
callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件 # if result.mute_time_remaining > 0:
""" # return True
async def wrapper(event: mirai.Event): # return False
await callback(event, self)
self.bot.on(event_type)(wrapper)
def unregister_listener( # def register_listener(
self, # self,
event_type: typing.Type[mirai.Event], # event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] # callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
): # ):
"""注销事件监听器 # """注册事件监听器
Args: # Args:
event_type (typing.Type[mirai.Event]): YiriMirai事件类型 # event_type (typing.Type[mirai.Event]): YiriMirai事件类型
callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件 # callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
""" # """
assert isinstance(self.bot, mirai.Mirai) # async def wrapper(event: mirai.Event):
bus = self.bot.bus # await callback(event, self)
assert isinstance(bus, mirai.models.bus.ModelEventBus) # self.bot.on(event_type)(wrapper)
bus.unsubscribe(event_type, callback) # def unregister_listener(
# self,
# event_type: typing.Type[mirai.Event],
# callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
# ):
# """注销事件监听器
async def run_async(self): # Args:
self.bot_account_id = self.bot.qq # event_type (typing.Type[mirai.Event]): YiriMirai事件类型
return await MiraiRunner(self.bot)._run() # callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
# """
# assert isinstance(self.bot, mirai.Mirai)
# bus = self.bot.bus
# assert isinstance(bus, mirai.models.bus.ModelEventBus)
async def kill(self) -> bool: # bus.unsubscribe(event_type, callback)
return False
# async def run_async(self):
# self.bot_account_id = self.bot.qq
# return await MiraiRunner(self.bot)._run()
# async def kill(self) -> bool:
# return False

View File

@ -0,0 +1,3 @@
from .entities import *
from .events import *
from .message import *

105
pkg/platform/types/base.py Normal file
View File

@ -0,0 +1,105 @@
from typing import Dict, List, Type
import pydantic.main as pdm
from pydantic import BaseModel
class PlatformMetaclass(pdm.ModelMetaclass):
"""此类是平台中使用的 pydantic 模型的元类的基类。"""
def to_camel(name: str) -> str:
"""将下划线命名风格转换为小驼峰命名。"""
if name[:2] == '__': # 不处理双下划线开头的特殊命名。
return name
name_parts = name.split('_')
return ''.join(name_parts[:1] + [x.title() for x in name_parts[1:]])
class PlatformBaseModel(BaseModel, metaclass=PlatformMetaclass):
"""模型基类。
启用了三项配置
1. 允许解析时传入额外的值并将额外值保存在模型中
2. 允许通过别名访问字段
3. 自动生成小驼峰风格的别名
"""
def __init__(self, *args, **kwargs):
""""""
super().__init__(*args, **kwargs)
def __repr__(self) -> str:
return self.__class__.__name__ + '(' + ', '.join(
(f'{k}={repr(v)}' for k, v in self.__dict__.items() if v)
) + ')'
class Config:
extra = 'allow'
allow_population_by_field_name = True
alias_generator = to_camel
class PlatformIndexedMetaclass(PlatformMetaclass):
"""可以通过子类名获取子类的类的元类。"""
__indexedbases__: List[Type['PlatformIndexedModel']] = []
__indexedmodel__ = None
def __new__(cls, name, bases, attrs, **kwargs):
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
# 第一类PlatformIndexedModel
if name == 'PlatformIndexedModel':
cls.__indexedmodel__ = new_cls
new_cls.__indexes__ = {}
return new_cls
# 第二类PlatformIndexedModel 的直接子类,这些是可以通过子类名获取子类的类。
if cls.__indexedmodel__ in bases:
cls.__indexedbases__.append(new_cls)
new_cls.__indexes__ = {}
return new_cls
# 第三类PlatformIndexedModel 的直接子类的子类,这些添加到直接子类的索引中。
for base in cls.__indexedbases__:
if issubclass(new_cls, base):
base.__indexes__[name] = new_cls
return new_cls
def __getitem__(cls, name):
return cls.get_subtype(name)
class PlatformIndexedModel(PlatformBaseModel, metaclass=PlatformIndexedMetaclass):
"""可以通过子类名获取子类的类。"""
__indexes__: Dict[str, Type['PlatformIndexedModel']]
@classmethod
def get_subtype(cls, name: str) -> Type['PlatformIndexedModel']:
"""根据类名称,获取相应的子类类型。
Args:
name: 类名称
Returns:
Type['PlatformIndexedModel']: 子类类型
"""
try:
type_ = cls.__indexes__.get(name)
if not (type_ and issubclass(type_, cls)):
raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!')
return type_
except AttributeError as e:
raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!') from None
@classmethod
def parse_subtype(cls, obj: dict) -> 'PlatformIndexedModel':
"""通过字典,构造对应的模型对象。
Args:
obj: 一个字典包含了模型对象的属性
Returns:
PlatformIndexedModel: 构造的对象
"""
if cls in PlatformIndexedModel.__subclasses__():
ModelType = cls.get_subtype(obj['type'])
return ModelType.parse_obj(obj)
return super().parse_obj(obj)

View File

@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-
"""
此模块提供实体和配置项模型
"""
import abc
from datetime import datetime
from enum import Enum
import typing
import pydantic
class Entity(pydantic.BaseModel):
"""实体,表示一个用户或群。"""
id: int
"""QQ 号或群号。"""
@abc.abstractmethod
def get_avatar_url(self) -> str:
"""头像图片链接。"""
@abc.abstractmethod
def get_name(self) -> str:
"""名称。"""
class Friend(Entity):
"""好友。"""
id: int
"""QQ 号。"""
nickname: typing.Optional[str]
"""昵称。"""
remark: typing.Optional[str]
"""备注。"""
def get_avatar_url(self) -> str:
return f'http://q4.qlogo.cn/g?b=qq&nk={self.id}&s=140'
def get_name(self) -> str:
return self.nickname or self.remark or ''
class Permission(str, Enum):
"""群成员身份权限。"""
Member = "MEMBER"
"""成员。"""
Administrator = "ADMINISTRATOR"
"""管理员。"""
Owner = "OWNER"
"""群主。"""
def __repr__(self) -> str:
return repr(self.value)
class Group(Entity):
"""群。"""
id: int
"""群号。"""
name: str
"""群名称。"""
permission: Permission
"""Bot 在群中的权限。"""
def get_avatar_url(self) -> str:
return f'https://p.qlogo.cn/gh/{self.id}/{self.id}/'
def get_name(self) -> str:
return self.name
class GroupMember(Entity):
"""群成员。"""
id: int
"""QQ 号。"""
member_name: str
"""群成员名称。"""
permission: Permission
"""Bot 在群中的权限。"""
group: Group
"""群。"""
special_title: str = ''
"""群头衔。"""
join_timestamp: datetime = datetime.utcfromtimestamp(0)
"""加入群的时间。"""
last_speak_timestamp: datetime = datetime.utcfromtimestamp(0)
"""最后一次发言的时间。"""
mute_time_remaining: int = 0
"""禁言剩余时间。"""
def get_avatar_url(self) -> str:
return f'http://q4.qlogo.cn/g?b=qq&nk={self.id}&s=140'
def get_name(self) -> str:
return self.member_name
class Client(Entity):
"""来自其他客户端的用户。"""
id: int
"""识别 id。"""
platform: str
"""来源平台。"""
def get_avatar_url(self) -> str:
raise NotImplementedError
def get_name(self) -> str:
return self.platform
class Subject(pydantic.BaseModel):
"""另一种实体类型表示。"""
id: int
"""QQ 号或群号。"""
kind: typing.Literal['Friend', 'Group', 'Stranger']
"""类型。"""
class Config(pydantic.BaseModel):
"""配置项类型。"""
def modify(self, **kwargs) -> 'Config':
"""修改部分设置。"""
for k, v in kwargs.items():
if k in self.__fields__:
setattr(self, k, v)
else:
raise ValueError(f'未知配置项: {k}')
return self
class GroupConfigModel(Config):
"""群配置。"""
name: str
"""群名称。"""
confess_talk: bool
"""是否允许坦白说。"""
allow_member_invite: bool
"""是否允许成员邀请好友入群。"""
auto_approve: bool
"""是否开启自动审批入群。"""
anonymous_chat: bool
"""是否开启匿名聊天。"""
announcement: str = ''
"""群公告。"""
class MemberInfoModel(Config, GroupMember):
"""群成员信息。"""

View File

@ -0,0 +1,124 @@
# -*- coding: utf-8 -*-
"""
此模块提供事件模型
"""
from datetime import datetime
from enum import Enum
import typing
import pydantic
from . import entities as platform_entities
from . import message as platform_message
class Event(pydantic.BaseModel):
"""事件基类。
Args:
type: 事件名
"""
type: str
"""事件名。"""
def __repr__(self):
return self.__class__.__name__ + '(' + ', '.join(
(
f'{k}={repr(v)}'
for k, v in self.__dict__.items() if k != 'type' and v
)
) + ')'
@classmethod
def parse_subtype(cls, obj: dict) -> 'Event':
try:
return typing.cast(Event, super().parse_subtype(obj))
except ValueError:
return Event(type=obj['type'])
@classmethod
def get_subtype(cls, name: str) -> typing.Type['Event']:
try:
return typing.cast(typing.Type[Event], super().get_subtype(name))
except ValueError:
return Event
###############################
# Bot Event
class BotEvent(Event):
"""Bot 自身事件。
Args:
type: 事件名
qq: Bot QQ
"""
type: str
"""事件名。"""
qq: int
"""Bot 的 QQ 号。"""
###############################
# Message Event
class MessageEvent(Event):
"""消息事件。
Args:
type: 事件名
message_chain: 消息内容
"""
type: str
"""事件名。"""
message_chain: platform_message.MessageChain
"""消息内容。"""
class FriendMessage(MessageEvent):
"""好友消息。
Args:
type: 事件名
sender: 发送消息的好友
message_chain: 消息内容
"""
type: str = 'FriendMessage'
"""事件名。"""
sender: platform_entities.Friend
"""发送消息的好友。"""
message_chain: platform_message.MessageChain
"""消息内容。"""
class GroupMessage(MessageEvent):
"""群消息。
Args:
type: 事件名
sender: 发送消息的群成员
message_chain: 消息内容
"""
type: str = 'GroupMessage'
"""事件名。"""
sender: platform_entities.GroupMember
"""发送消息的群成员。"""
message_chain: platform_message.MessageChain
"""消息内容。"""
@property
def group(self) -> platform_entities.Group:
return self.sender.group
class StrangerMessage(MessageEvent):
"""陌生人消息。
Args:
type: 事件名
sender: 发送消息的人
message_chain: 消息内容
"""
type: str = 'StrangerMessage'
"""事件名。"""
sender: platform_entities.Friend
"""发送消息的人。"""
message_chain: platform_message.MessageChain
"""消息内容。"""

View File

@ -0,0 +1,817 @@
import itertools
import logging
from datetime import datetime
from enum import Enum
from pathlib import Path
import typing
import pydantic
import pydantic.main
from . import entities as platform_entities
from .base import PlatformBaseModel, PlatformIndexedMetaclass, PlatformIndexedModel
logger = logging.getLogger(__name__)
class MessageComponentMetaclass(PlatformIndexedMetaclass):
"""消息组件元类。"""
__message_component__ = None
def __new__(cls, name, bases, attrs, **kwargs):
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
if name == 'MessageComponent':
cls.__message_component__ = new_cls
if not cls.__message_component__:
return new_cls
for base in bases:
if issubclass(base, cls.__message_component__):
# 获取字段名
if hasattr(new_cls, '__fields__'):
# 忽略 type 字段
new_cls.__parameter_names__ = list(new_cls.__fields__)[1:]
else:
new_cls.__parameter_names__ = []
break
return new_cls
class MessageComponent(PlatformIndexedModel, metaclass=MessageComponentMetaclass):
"""消息组件。"""
type: str
"""消息组件类型。"""
def __str__(self):
return ''
def __repr__(self):
return self.__class__.__name__ + '(' + ', '.join(
(
f'{k}={repr(v)}'
for k, v in self.__dict__.items() if k != 'type' and v
)
) + ')'
def __init__(self, *args, **kwargs):
# 解析参数列表,将位置参数转化为具名参数
parameter_names = self.__parameter_names__
if len(args) > len(parameter_names):
raise TypeError(
f'`{self.type}`需要{len(parameter_names)}个参数,但传入了{len(args)}个。'
)
for name, value in zip(parameter_names, args):
if name in kwargs:
raise TypeError(f'在 `{self.type}` 中,具名参数 `{name}` 与位置参数重复。')
kwargs[name] = value
super().__init__(**kwargs)
TMessageComponent = typing.TypeVar('TMessageComponent', bound=MessageComponent)
class MessageChain(PlatformBaseModel):
"""消息链。
一个构造消息链的例子
```py
message_chain = MessageChain([
AtAll(),
Plain("Hello World!"),
])
```
`Plain` 可以省略
```py
message_chain = MessageChain([
AtAll(),
"Hello World!",
])
```
在调用 API 参数中需要 MessageChain 也可以使用 `List[MessageComponent]` 代替
例如以下两种写法是等价的
```py
await bot.send_friend_message(12345678, [
Plain("Hello World!")
])
```
```py
await bot.send_friend_message(12345678, MessageChain([
Plain("Hello World!")
]))
```
可以使用 `in` 运算检查消息链中
1. 是否有某个消息组件
2. 是否有某个类型的消息组件
```py
if AtAll in message_chain:
print('AtAll')
if At(bot.qq) in message_chain:
print('At Me')
```
消息链对索引操作进行了增强以消息组件类型为索引获取消息链中的全部该类型的消息组件
```py
plain_list = message_chain[Plain]
'[Plain("Hello World!")]'
```
可以用加号连接两个消息链
```py
MessageChain(['Hello World!']) + MessageChain(['Goodbye World!'])
# 返回 MessageChain([Plain("Hello World!"), Plain("Goodbye World!")])
```
"""
__root__: typing.List[MessageComponent]
@staticmethod
def _parse_message_chain(msg_chain: typing.Iterable):
result = []
for msg in msg_chain:
if isinstance(msg, dict):
result.append(MessageComponent.parse_subtype(msg))
elif isinstance(msg, MessageComponent):
result.append(msg)
elif isinstance(msg, str):
result.append(Plain(msg))
else:
raise TypeError(
f"消息链中元素需为 dict 或 str 或 MessageComponent当前类型{type(msg)}"
)
return result
@pydantic.validator('__root__', always=True, pre=True)
def _parse_component(cls, msg_chain):
if isinstance(msg_chain, (str, MessageComponent)):
msg_chain = [msg_chain]
if not msg_chain:
msg_chain = []
return cls._parse_message_chain(msg_chain)
@classmethod
def parse_obj(cls, msg_chain: typing.Iterable):
"""通过列表形式的消息链,构造对应的 `MessageChain` 对象。
Args:
msg_chain: 列表形式的消息链
"""
result = cls._parse_message_chain(msg_chain)
return cls(__root__=result)
def __init__(self, __root__: typing.Iterable[MessageComponent] = None):
super().__init__(__root__=__root__)
def __str__(self):
return "".join(str(component) for component in self.__root__)
def __repr__(self):
return f'{self.__class__.__name__}({self.__root__!r})'
def __iter__(self):
yield from self.__root__
def get_first(self,
t: typing.Type[TMessageComponent]) -> typing.Optional[TMessageComponent]:
"""获取消息链中第一个符合类型的消息组件。"""
for component in self:
if isinstance(component, t):
return component
return None
@typing.overload
def __getitem__(self, index: int) -> MessageComponent:
...
@typing.overload
def __getitem__(self, index: slice) -> typing.List[MessageComponent]:
...
@typing.overload
def __getitem__(self,
index: typing.Type[TMessageComponent]) -> typing.List[TMessageComponent]:
...
@typing.overload
def __getitem__(
self, index: typing.Tuple[typing.Type[TMessageComponent], int]
) -> typing.List[TMessageComponent]:
...
def __getitem__(
self, index: typing.Union[int, slice, typing.Type[TMessageComponent],
typing.Tuple[typing.Type[TMessageComponent], int]]
) -> typing.Union[MessageComponent, typing.List[MessageComponent],
typing.List[TMessageComponent]]:
return self.get(index)
def __setitem__(
self, key: typing.Union[int, slice],
value: typing.Union[MessageComponent, str, typing.Iterable[typing.Union[MessageComponent,
str]]]
):
if isinstance(value, str):
value = Plain(value)
if isinstance(value, typing.Iterable):
value = (Plain(c) if isinstance(c, str) else c for c in value)
self.__root__[key] = value # type: ignore
def __delitem__(self, key: typing.Union[int, slice]):
del self.__root__[key]
def __reversed__(self) -> typing.Iterable[MessageComponent]:
return reversed(self.__root__)
def has(
self, sub: typing.Union[MessageComponent, typing.Type[MessageComponent],
'MessageChain', str]
) -> bool:
"""判断消息链中:
1. 是否有某个消息组件
2. 是否有某个类型的消息组件
Args:
sub (`Union[MessageComponent, Type[MessageComponent], 'MessageChain', str]`):
若为 `MessageComponent`则判断该组件是否在消息链中
若为 `Type[MessageComponent]`则判断该组件类型是否在消息链中
Returns:
bool: 是否找到
"""
if isinstance(sub, type): # 检测消息链中是否有某种类型的对象
for i in self:
if type(i) is sub:
return True
return False
if isinstance(sub, MessageComponent): # 检查消息链中是否有某个组件
for i in self:
if i == sub:
return True
return False
raise TypeError(f"类型不匹配,当前类型:{type(sub)}")
def __contains__(self, sub) -> bool:
return self.has(sub)
def __ge__(self, other):
return other in self
def __len__(self) -> int:
return len(self.__root__)
def __add__(
self, other: typing.Union['MessageChain', MessageComponent, str]
) -> 'MessageChain':
if isinstance(other, MessageChain):
return self.__class__(self.__root__ + other.__root__)
if isinstance(other, str):
return self.__class__(self.__root__ + [Plain(other)])
if isinstance(other, MessageComponent):
return self.__class__(self.__root__ + [other])
return NotImplemented
def __radd__(self, other: typing.Union[MessageComponent, str]) -> 'MessageChain':
if isinstance(other, MessageComponent):
return self.__class__([other] + self.__root__)
if isinstance(other, str):
return self.__class__(
[typing.cast(MessageComponent, Plain(other))] + self.__root__
)
return NotImplemented
def __mul__(self, other: int):
if isinstance(other, int):
return self.__class__(self.__root__ * other)
return NotImplemented
def __rmul__(self, other: int):
return self.__mul__(other)
def __iadd__(self, other: typing.Iterable[typing.Union[MessageComponent, str]]):
self.extend(other)
def __imul__(self, other: int):
if isinstance(other, int):
self.__root__ *= other
return NotImplemented
def index(
self,
x: typing.Union[MessageComponent, typing.Type[MessageComponent]],
i: int = 0,
j: int = -1
) -> int:
"""返回 x 在消息链中首次出现项的索引号(索引号在 i 或其后且在 j 之前)。
Args:
x (`Union[MessageComponent, Type[MessageComponent]]`):
要查找的消息元素或消息元素类型
i: 从哪个位置开始查找
j: 查找到哪个位置结束
Returns:
int: 如果找到则返回索引号
Raises:
ValueError: 没有找到
TypeError: 类型不匹配
"""
if isinstance(x, type):
l = len(self)
if i < 0:
i += l
if i < 0:
i = 0
if j < 0:
j += l
if j > l:
j = l
for index in range(i, j):
if type(self[index]) is x:
return index
raise ValueError("消息链中不存在该类型的组件。")
if isinstance(x, MessageComponent):
return self.__root__.index(x, i, j)
raise TypeError(f"类型不匹配,当前类型:{type(x)}")
def count(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]) -> int:
"""返回消息链中 x 出现的次数。
Args:
x (`Union[MessageComponent, Type[MessageComponent]]`):
要查找的消息元素或消息元素类型
Returns:
int: 次数
"""
if isinstance(x, type):
return sum(1 for i in self if type(i) is x)
if isinstance(x, MessageComponent):
return self.__root__.count(x)
raise TypeError(f"类型不匹配,当前类型:{type(x)}")
def extend(self, x: typing.Iterable[typing.Union[MessageComponent, str]]):
"""将另一个消息链中的元素添加到消息链末尾。
Args:
x: 另一个消息链也可为消息元素或字符串元素的序列
"""
self.__root__.extend(Plain(c) if isinstance(c, str) else c for c in x)
def append(self, x: typing.Union[MessageComponent, str]):
"""将一个消息元素或字符串元素添加到消息链末尾。
Args:
x: 消息元素或字符串元素
"""
self.__root__.append(Plain(x) if isinstance(x, str) else x)
def insert(self, i: int, x: typing.Union[MessageComponent, str]):
"""将一个消息元素或字符串添加到消息链中指定位置。
Args:
i: 插入位置
x: 消息元素或字符串元素
"""
self.__root__.insert(i, Plain(x) if isinstance(x, str) else x)
def pop(self, i: int = -1) -> MessageComponent:
"""从消息链中移除并返回指定位置的元素。
Args:
i: 移除位置默认为末尾
Returns:
MessageComponent: 移除的元素
"""
return self.__root__.pop(i)
def remove(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]):
"""从消息链中移除指定元素或指定类型的一个元素。
Args:
x: 指定的元素或元素类型
"""
if isinstance(x, type):
self.pop(self.index(x))
if isinstance(x, MessageComponent):
self.__root__.remove(x)
def exclude(
self,
x: typing.Union[MessageComponent, typing.Type[MessageComponent]],
count: int = -1
) -> 'MessageChain':
"""返回移除指定元素或指定类型的元素后剩余的消息链。
Args:
x: 指定的元素或元素类型
count: 至多移除的数量默认为全部移除
Returns:
MessageChain: 剩余的消息链
"""
def _exclude():
nonlocal count
x_is_type = isinstance(x, type)
for c in self:
if count > 0 and ((x_is_type and type(c) is x) or c == x):
count -= 1
continue
yield c
return self.__class__(_exclude())
def reverse(self):
"""将消息链原地翻转。"""
self.__root__.reverse()
@classmethod
def join(cls, *args: typing.Iterable[typing.Union[str, MessageComponent]]):
return cls(
Plain(c) if isinstance(c, str) else c
for c in itertools.chain(*args)
)
@property
def source(self) -> typing.Optional['Source']:
"""获取消息链中的 `Source` 对象。"""
return self.get_first(Source)
@property
def message_id(self) -> int:
"""获取消息链的 message_id若无法获取返回 -1。"""
source = self.source
return source.id if source else -1
TMessage = typing.Union[MessageChain, typing.Iterable[typing.Union[MessageComponent, str]],
MessageComponent, str]
"""可以转化为 MessageChain 的类型。"""
class Source(MessageComponent):
"""源。包含消息的基本信息。"""
type: str = "Source"
"""消息组件类型。"""
id: int
"""消息的识别号用于引用回复Source 类型永远为 MessageChain 的第一个元素)。"""
time: datetime
"""消息时间。"""
class Plain(MessageComponent):
"""纯文本。"""
type: str = "Plain"
"""消息组件类型。"""
text: str
"""文字消息。"""
def __str__(self):
return self.text
def __repr__(self):
return f'Plain({self.text!r})'
class Quote(MessageComponent):
"""引用。"""
type: str = "Quote"
"""消息组件类型。"""
id: typing.Optional[int] = None
"""被引用回复的原消息的 message_id。"""
group_id: typing.Optional[int] = None
"""被引用回复的原消息所接收的群号当为好友消息时为0。"""
sender_id: typing.Optional[int] = None
"""被引用回复的原消息的发送者的QQ号。"""
target_id: typing.Optional[int] = None
"""被引用回复的原消息的接收者者的QQ号或群号"""
origin: MessageChain
"""被引用回复的原消息的消息链对象。"""
@pydantic.validator("origin", always=True, pre=True)
def origin_formater(cls, v):
return MessageChain.parse_obj(v)
class At(MessageComponent):
"""At某人。"""
type: str = "At"
"""消息组件类型。"""
target: int
"""群员 QQ 号。"""
display: typing.Optional[str] = None
"""At时显示的文字发送消息时无效自动使用群名片。"""
def __eq__(self, other):
return isinstance(other, At) and self.target == other.target
def __str__(self):
return f"@{self.display or self.target}"
class AtAll(MessageComponent):
"""At全体。"""
type: str = "AtAll"
"""消息组件类型。"""
def __str__(self):
return "@全体成员"
class Image(MessageComponent):
"""图片。"""
type: str = "Image"
"""消息组件类型。"""
image_id: typing.Optional[str] = None
"""图片的 image_id群图片与好友图片格式不同。不为空时将忽略 url 属性。"""
url: typing.Optional[pydantic.HttpUrl] = None
"""图片的 URL发送时可作网络图片的链接接收时为腾讯图片服务器的链接可用于图片下载。"""
path: typing.Union[str, Path, None] = None
"""图片的路径,发送本地图片。"""
base64: typing.Optional[str] = None
"""图片的 Base64 编码。"""
def __eq__(self, other):
return isinstance(
other, Image
) and self.type == other.type and self.uuid == other.uuid
def __str__(self):
return '[图片]'
@pydantic.validator('path')
def validate_path(cls, path: typing.Union[str, Path, None]):
"""修复 path 参数的行为,使之相对于 QChatGPT 的启动路径。"""
if path:
try:
return str(Path(path).resolve(strict=True))
except FileNotFoundError:
raise ValueError(f"无效路径:{path}")
else:
return path
@property
def uuid(self):
image_id = self.image_id
if image_id[0] == '{': # 群图片
image_id = image_id[1:37]
elif image_id[0] == '/': # 好友图片
image_id = image_id[1:]
return image_id
async def download(
self,
filename: typing.Union[str, Path, None] = None,
directory: typing.Union[str, Path, None] = None,
determine_type: bool = True
):
"""下载图片到本地。
Args:
filename: 下载到本地的文件路径 `directory` 二选一
directory: 下载到本地的文件夹路径 `filename` 二选一
determine_type: 是否自动根据图片类型确定拓展名默认为 True
"""
if not self.url:
logger.warning(f'图片 `{self.uuid}` 无 url 参数,下载失败。')
return
import httpx
async with httpx.AsyncClient() as client:
response = await client.get(self.url)
response.raise_for_status()
content = response.content
if filename:
path = Path(filename)
if determine_type:
import imghdr
path = path.with_suffix(
'.' + str(imghdr.what(None, content))
)
path.parent.mkdir(parents=True, exist_ok=True)
elif directory:
import imghdr
path = Path(directory)
path.mkdir(parents=True, exist_ok=True)
path = path / f'{self.uuid}.{imghdr.what(None, content)}'
else:
raise ValueError("请指定文件路径或文件夹路径!")
import aiofiles
async with aiofiles.open(path, 'wb') as f:
await f.write(content)
return path
@classmethod
async def from_local(
cls,
filename: typing.Union[str, Path, None] = None,
content: typing.Optional[bytes] = None,
) -> "Image":
"""从本地文件路径加载图片,以 base64 的形式传递。
Args:
filename: 从本地文件路径加载图片 `content` 二选一
content: 从本地文件内容加载图片 `filename` 二选一
Returns:
Image: 图片对象
"""
if content:
pass
elif filename:
path = Path(filename)
import aiofiles
async with aiofiles.open(path, 'rb') as f:
content = await f.read()
else:
raise ValueError("请指定图片路径或图片内容!")
import base64
img = cls(base64=base64.b64encode(content).decode())
return img
@classmethod
def from_unsafe_path(cls, path: typing.Union[str, Path]) -> "Image":
"""从不安全的路径加载图片。
Args:
path: 从不安全的路径加载图片
Returns:
Image: 图片对象
"""
return cls.construct(path=str(path))
class Unknown(MessageComponent):
"""未知。"""
type: str = "Unknown"
"""消息组件类型。"""
text: str
"""文本。"""
class Voice(MessageComponent):
"""语音。"""
type: str = "Voice"
"""消息组件类型。"""
voice_id: typing.Optional[str] = None
"""语音的 voice_id不为空时将忽略 url 属性。"""
url: typing.Optional[str] = None
"""语音的 URL发送时可作网络语音的链接接收时为腾讯语音服务器的链接可用于语音下载。"""
path: typing.Optional[str] = None
"""语音的路径,发送本地语音。"""
base64: typing.Optional[str] = None
"""语音的 Base64 编码。"""
length: typing.Optional[int] = None
"""语音的长度,单位为秒。"""
@pydantic.validator('path')
def validate_path(cls, path: typing.Optional[str]):
"""修复 path 参数的行为,使之相对于 QChatGPT 的启动路径。"""
if path:
try:
return str(Path(path).resolve(strict=True))
except FileNotFoundError:
raise ValueError(f"无效路径:{path}")
else:
return path
def __str__(self):
return '[语音]'
async def download(
self,
filename: typing.Union[str, Path, None] = None,
directory: typing.Union[str, Path, None] = None
):
"""下载语音到本地。
语音采用 silk v3 格式silk 格式的编码解码请使用 [graiax-silkcoder](https://pypi.org/project/graiax-silkcoder/)
Args:
filename: 下载到本地的文件路径 `directory` 二选一
directory: 下载到本地的文件夹路径 `filename` 二选一
"""
if not self.url:
logger.warning(f'语音 `{self.voice_id}` 无 url 参数,下载失败。')
return
import httpx
async with httpx.AsyncClient() as client:
response = await client.get(self.url)
response.raise_for_status()
content = response.content
if filename:
path = Path(filename)
path.parent.mkdir(parents=True, exist_ok=True)
elif directory:
path = Path(directory)
path.mkdir(parents=True, exist_ok=True)
path = path / f'{self.voice_id}.silk'
else:
raise ValueError("请指定文件路径或文件夹路径!")
import aiofiles
async with aiofiles.open(path, 'wb') as f:
await f.write(content)
@classmethod
async def from_local(
cls,
filename: typing.Union[str, Path, None] = None,
content: typing.Optional[bytes] = None,
) -> "Voice":
"""从本地文件路径加载语音,以 base64 的形式传递。
Args:
filename: 从本地文件路径加载语音 `content` 二选一
content: 从本地文件内容加载语音 `filename` 二选一
"""
if content:
pass
if filename:
path = Path(filename)
import aiofiles
async with aiofiles.open(path, 'rb') as f:
content = await f.read()
else:
raise ValueError("请指定语音路径或语音内容!")
import base64
img = cls(base64=base64.b64encode(content).decode())
return img
class ForwardMessageNode(pydantic.BaseModel):
"""合并转发中的一条消息。"""
sender_id: typing.Optional[int] = None
"""发送人QQ号。"""
sender_name: typing.Optional[str] = None
"""显示名称。"""
message_chain: typing.Optional[MessageChain] = None
"""消息内容。"""
message_id: typing.Optional[int] = None
"""消息的 message_id可以只使用此属性从缓存中读取消息内容。"""
time: typing.Optional[datetime] = None
"""发送时间。"""
@pydantic.validator('message_chain', check_fields=False)
def _validate_message_chain(cls, value: typing.Union[MessageChain, list]):
if isinstance(value, list):
return MessageChain.parse_obj(value)
return value
@classmethod
def create(
cls, sender: typing.Union[platform_entities.Friend, platform_entities.GroupMember], message: MessageChain
) -> 'ForwardMessageNode':
"""从消息链生成转发消息。
Args:
sender: 发送人
message: 消息内容
Returns:
ForwardMessageNode: 生成的一条消息
"""
return ForwardMessageNode(
sender_id=sender.id,
sender_name=sender.get_name(),
message_chain=message
)
class Forward(MessageComponent):
"""合并转发。"""
type: str = "Forward"
"""消息组件类型。"""
node_list: typing.List[ForwardMessageNode]
"""转发消息节点列表。"""
def __init__(self, *args, **kwargs):
if len(args) == 1:
self.node_list = args[0]
super().__init__(**kwargs)
super().__init__(*args, **kwargs)
def __str__(self):
return '[聊天记录]'
class File(MessageComponent):
"""文件。"""
type: str = "File"
"""消息组件类型。"""
id: str
"""文件识别 ID。"""
name: str
"""文件名称。"""
size: int
"""文件大小。"""
def __str__(self):
return f'[文件]{self.name}'

View File

@ -3,11 +3,11 @@ from __future__ import annotations
import typing import typing
import abc import abc
import pydantic import pydantic
import mirai
from . import events from . import events
from ..provider.tools import entities as tools_entities from ..provider.tools import entities as tools_entities
from ..core import app from ..core import app
from ..platform.types import message as platform_message
def register( def register(
@ -174,11 +174,11 @@ class EventContext:
self.__return_value__[key] = [] self.__return_value__[key] = []
self.__return_value__[key].append(ret) self.__return_value__[key].append(ret)
async def reply(self, message_chain: mirai.MessageChain): async def reply(self, message_chain: platform_message.MessageChain):
"""回复此次消息请求 """回复此次消息请求
Args: Args:
message_chain (mirai.MessageChain): YiriMirai库的消息链若用户使用的不是 YiriMirai 适配器程序也能自动转换为目标消息链 message_chain (platform.types.MessageChain): 源平台的消息链若用户使用的不是源平台适配器程序也能自动转换为目标平台消息链
""" """
await self.host.ap.platform_mgr.send( await self.host.ap.platform_mgr.send(
event=self.event.query.message_event, event=self.event.query.message_event,
@ -190,14 +190,14 @@ class EventContext:
self, self,
target_type: str, target_type: str,
target_id: str, target_id: str,
message: mirai.MessageChain message: platform_message.MessageChain
): ):
"""主动发送消息 """主动发送消息
Args: Args:
target_type (str): 目标类型`person``group` target_type (str): 目标类型`person``group`
target_id (str): 目标ID target_id (str): 目标ID
message (mirai.MessageChain): YiriMirai库的消息链若用户使用的不是 YiriMirai 适配器程序也能自动转换为目标消息链 message (platform.types.MessageChain): 源平台的消息链若用户使用的不是源平台适配器程序也能自动转换为目标平台消息链
""" """
await self.event.query.adapter.send_message( await self.event.query.adapter.send_message(
target_type=target_type, target_type=target_type,

View File

@ -3,10 +3,10 @@ from __future__ import annotations
import typing import typing
import pydantic import pydantic
import mirai
from ..core import entities as core_entities from ..core import entities as core_entities
from ..provider import entities as llm_entities from ..provider import entities as llm_entities
from ..platform.types import message as platform_message
class BaseEventModel(pydantic.BaseModel): class BaseEventModel(pydantic.BaseModel):
@ -31,7 +31,7 @@ class PersonMessageReceived(BaseEventModel):
sender_id: int sender_id: int
"""发送者ID(QQ号)""" """发送者ID(QQ号)"""
message_chain: mirai.MessageChain message_chain: platform_message.MessageChain
class GroupMessageReceived(BaseEventModel): class GroupMessageReceived(BaseEventModel):
@ -43,7 +43,7 @@ class GroupMessageReceived(BaseEventModel):
sender_id: int sender_id: int
message_chain: mirai.MessageChain message_chain: platform_message.MessageChain
class PersonNormalMessageReceived(BaseEventModel): class PersonNormalMessageReceived(BaseEventModel):

View File

@ -48,6 +48,8 @@ class PluginManager:
# 按优先级倒序 # 按优先级倒序
self.plugins.sort(key=lambda x: x.priority, reverse=True) self.plugins.sort(key=lambda x: x.priority, reverse=True)
self.ap.logger.debug(f'优先级排序后的插件列表 {self.plugins}')
async def initialize_plugins(self): async def initialize_plugins(self):
for plugin in self.plugins: for plugin in self.plugins:
try: try:

View File

@ -45,6 +45,7 @@ class SettingManager:
for plugin_container in plugin_containers: for plugin_container in plugin_containers:
if plugin_container.plugin_name == value['name']: if plugin_container.plugin_name == value['name']:
plugin_container.set_from_setting_dict(value) plugin_container.set_from_setting_dict(value)
break
self.settings.data = { self.settings.data = {
'plugins': [ 'plugins': [

View File

@ -4,7 +4,8 @@ import typing
import enum import enum
import pydantic import pydantic
import mirai
from ..platform.types import message as platform_message
class FunctionCall(pydantic.BaseModel): class FunctionCall(pydantic.BaseModel):
@ -73,14 +74,14 @@ class Message(pydantic.BaseModel):
def readable_str(self) -> str: def readable_str(self) -> str:
if self.content is not None: if self.content is not None:
return str(self.role) + ": " + str(self.get_content_mirai_message_chain()) return str(self.role) + ": " + str(self.get_content_platform_message_chain())
elif self.tool_calls is not None: elif self.tool_calls is not None:
return f'调用工具: {self.tool_calls[0].id}' return f'调用工具: {self.tool_calls[0].id}'
else: else:
return '未知消息' return '未知消息'
def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None: def get_content_platform_message_chain(self, prefix_text: str="") -> platform_message.MessageChain | None:
"""将内容转换为 Mirai MessageChain 对象 """将内容转换为平台消息 MessageChain 对象
Args: Args:
prefix_text (str): 首个文字组件的前缀文本 prefix_text (str): 首个文字组件的前缀文本
@ -89,15 +90,15 @@ class Message(pydantic.BaseModel):
if self.content is None: if self.content is None:
return None return None
elif isinstance(self.content, str): elif isinstance(self.content, str):
return mirai.MessageChain([mirai.Plain(prefix_text+self.content)]) return platform_message.MessageChain([platform_message.Plain(prefix_text+self.content)])
elif isinstance(self.content, list): elif isinstance(self.content, list):
mc = [] mc = []
for ce in self.content: for ce in self.content:
if ce.type == 'text': if ce.type == 'text':
mc.append(mirai.Plain(ce.text)) mc.append(platform_message.Plain(ce.text))
elif ce.type == 'image_url': elif ce.type == 'image_url':
if ce.image_url.url.startswith("http"): if ce.image_url.url.startswith("http"):
mc.append(mirai.Image(url=ce.image_url.url)) mc.append(platform_message.Image(url=ce.image_url.url))
else: # base64 else: # base64
b64_str = ce.image_url.url b64_str = ce.image_url.url
@ -105,15 +106,15 @@ class Message(pydantic.BaseModel):
if b64_str.startswith("data:"): if b64_str.startswith("data:"):
b64_str = b64_str.split(",")[1] b64_str = b64_str.split(",")[1]
mc.append(mirai.Image(base64=b64_str)) mc.append(platform_message.Image(base64=b64_str))
# 找第一个文字组件 # 找第一个文字组件
if prefix_text: if prefix_text:
for i, c in enumerate(mc): for i, c in enumerate(mc):
if isinstance(c, mirai.Plain): if isinstance(c, platform_message.Plain):
mc[i] = mirai.Plain(prefix_text+c.text) mc[i] = platform_message.Plain(prefix_text+c.text)
break break
else: else:
mc.insert(0, mirai.Plain(prefix_text)) mc.insert(0, platform_message.Plain(prefix_text))
return mirai.MessageChain(mc) return platform_message.MessageChain(mc)

View File

@ -2,7 +2,6 @@ requests
openai>1.0.0 openai>1.0.0
anthropic anthropic
colorlog~=6.6.0 colorlog~=6.6.0
yiri-mirai-rc
aiocqhttp aiocqhttp
qq-botpy qq-botpy
nakuru-project-idk nakuru-project-idk