mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
Merge pull request #890 from RockChinQ/feat/more-platforms
Refactor: 移除 YiriMirai 组件
This commit is contained in:
commit
ea6a0af5a7
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
|
@ -8,10 +8,10 @@ body:
|
|||
label: 消息平台适配器
|
||||
description: "连接QQ使用的框架"
|
||||
options:
|
||||
- yiri-mirai(Mirai)
|
||||
- Nakuru(go-cqhttp)
|
||||
- aiocqhttp(使用 OneBot 协议接入的)
|
||||
- qq-botpy(QQ官方API)
|
||||
- yiri-mirai(Mirai)
|
||||
validations:
|
||||
required: false
|
||||
- type: input
|
||||
|
|
1
.github/dependabot.yml
vendored
1
.github/dependabot.yml
vendored
|
@ -10,5 +10,4 @@ updates:
|
|||
schedule:
|
||||
interval: "weekly"
|
||||
allow:
|
||||
- dependency-name: "yiri-mirai-rc"
|
||||
- dependency-name: "openai"
|
||||
|
|
3
.github/pull_request_template.md
vendored
3
.github/pull_request_template.md
vendored
|
@ -6,8 +6,11 @@
|
|||
|
||||
### PR 作者完成
|
||||
|
||||
*请在方括号间写`x`以打勾
|
||||
|
||||
- [ ] 阅读仓库[贡献指引](https://github.com/RockChinQ/QChatGPT/blob/master/CONTRIBUTING.md)了吗?
|
||||
- [ ] 与项目所有者沟通过了吗?
|
||||
- [ ] 我确定已自行测试所作的更改,确保功能符合预期。
|
||||
|
||||
### 项目所有者完成
|
||||
|
||||
|
|
|
@ -3,10 +3,10 @@ from __future__ import annotations
|
|||
import typing
|
||||
|
||||
import pydantic
|
||||
import mirai
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from . import errors, operator
|
||||
from ..platform.types import message as platform_message
|
||||
|
||||
|
||||
class CommandReturn(pydantic.BaseModel):
|
||||
|
@ -17,7 +17,7 @@ class CommandReturn(pydantic.BaseModel):
|
|||
"""文本
|
||||
"""
|
||||
|
||||
image: typing.Optional[mirai.Image] = None
|
||||
image: typing.Optional[platform_message.Image] = None
|
||||
"""弃用"""
|
||||
|
||||
image_url: typing.Optional[str] = None
|
||||
|
|
|
@ -5,7 +5,6 @@ required_deps = {
|
|||
"openai": "openai",
|
||||
"anthropic": "anthropic",
|
||||
"colorlog": "colorlog",
|
||||
"mirai": "yiri-mirai-rc",
|
||||
"aiocqhttp": "aiocqhttp",
|
||||
"botpy": "qq-botpy",
|
||||
"PIL": "pillow",
|
||||
|
|
|
@ -6,13 +6,15 @@ import datetime
|
|||
import asyncio
|
||||
|
||||
import pydantic
|
||||
import mirai
|
||||
|
||||
from ..provider import entities as llm_entities
|
||||
from ..provider.modelmgr import entities
|
||||
from ..provider.sysprompt import entities as sysprompt_entities
|
||||
from ..provider.tools import entities as tools_entities
|
||||
from ..platform import adapter as msadapter
|
||||
from ..platform.types import message as platform_message
|
||||
from ..platform.types import events as platform_events
|
||||
from ..platform.types import entities as platform_entities
|
||||
|
||||
|
||||
class LauncherTypes(enum.Enum):
|
||||
|
@ -40,10 +42,10 @@ class Query(pydantic.BaseModel):
|
|||
sender_id: int
|
||||
"""发送者ID,platform处理阶段设置"""
|
||||
|
||||
message_event: mirai.MessageEvent
|
||||
message_event: platform_events.MessageEvent
|
||||
"""事件,platform收到的原始事件"""
|
||||
|
||||
message_chain: mirai.MessageChain
|
||||
message_chain: platform_message.MessageChain
|
||||
"""消息链,platform收到的原始消息链"""
|
||||
|
||||
adapter: msadapter.MessageSourceAdapter
|
||||
|
@ -67,10 +69,10 @@ class Query(pydantic.BaseModel):
|
|||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
|
||||
"""使用的函数,由前置处理器阶段设置"""
|
||||
|
||||
resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[mirai.MessageChain]] = []
|
||||
resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] = []
|
||||
"""由Process阶段生成的回复消息对象列表"""
|
||||
|
||||
resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None
|
||||
resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
|
||||
"""回复消息链,从resp_messages包装而得"""
|
||||
|
||||
# ======= 内部保留 =======
|
||||
|
@ -108,7 +110,7 @@ class Session(pydantic.BaseModel):
|
|||
|
||||
using_conversation: typing.Optional[Conversation] = None
|
||||
|
||||
conversations: typing.Optional[list[Conversation]] = []
|
||||
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
|
|
|
@ -1,9 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import mirai
|
||||
import mirai.models
|
||||
import mirai.models.message
|
||||
|
||||
from ...core import app
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
|
@ -12,6 +8,9 @@ from ...config import manager as cfg_mgr
|
|||
from . import filter as filter_model, entities as filter_entities
|
||||
from .filters import cntignore, banwords, baiduexamine
|
||||
from ...provider import entities as llm_entities
|
||||
from ...platform.types import message as platform_message
|
||||
from ...platform.types import events as platform_events
|
||||
from ...platform.types import entities as platform_entities
|
||||
|
||||
|
||||
@stage.stage_class('PostContentFilterStage')
|
||||
|
@ -89,8 +88,8 @@ class ContentFilterStage(stage.PipelineStage):
|
|||
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
|
||||
message = result.replacement
|
||||
|
||||
query.message_chain = mirai.MessageChain(
|
||||
mirai.Plain(message)
|
||||
query.message_chain = platform_message.MessageChain(
|
||||
platform_message.Plain(message)
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
|
@ -148,7 +147,7 @@ class ContentFilterStage(stage.PipelineStage):
|
|||
|
||||
contain_non_text = False
|
||||
|
||||
text_components = [mirai.Plain, mirai.models.message.Source]
|
||||
text_components = [platform_message.Plain, platform_message.Source]
|
||||
|
||||
for me in query.message_chain:
|
||||
if type(me) not in text_components:
|
||||
|
|
|
@ -4,11 +4,11 @@ import asyncio
|
|||
import typing
|
||||
import traceback
|
||||
|
||||
import mirai
|
||||
|
||||
from ..core import app, entities
|
||||
from . import entities as pipeline_entities
|
||||
from ..plugin import events
|
||||
from ..platform.types import message as platform_message
|
||||
|
||||
|
||||
class Controller:
|
||||
|
@ -73,11 +73,11 @@ class Controller:
|
|||
# 处理str类型
|
||||
|
||||
if isinstance(result.user_notice, str):
|
||||
result.user_notice = mirai.MessageChain(
|
||||
mirai.Plain(result.user_notice)
|
||||
result.user_notice = platform_message.MessageChain(
|
||||
platform_message.Plain(result.user_notice)
|
||||
)
|
||||
elif isinstance(result.user_notice, list):
|
||||
result.user_notice = mirai.MessageChain(
|
||||
result.user_notice = platform_message.MessageChain(
|
||||
*result.user_notice
|
||||
)
|
||||
|
||||
|
|
|
@ -4,8 +4,7 @@ import enum
|
|||
import typing
|
||||
|
||||
import pydantic
|
||||
import mirai
|
||||
import mirai.models.message as mirai_message
|
||||
from ..platform.types import message as platform_message
|
||||
|
||||
from ..core import entities
|
||||
|
||||
|
@ -25,13 +24,9 @@ class StageProcessResult(pydantic.BaseModel):
|
|||
|
||||
new_query: entities.Query
|
||||
|
||||
user_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
|
||||
user_notice: typing.Optional[typing.Union[str, list[platform_message.MessageComponent], platform_message.MessageChain, None]] = []
|
||||
"""只要设置了就会发送给用户"""
|
||||
|
||||
# TODO delete
|
||||
# admin_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
|
||||
"""只要设置了就会发送给管理员"""
|
||||
|
||||
console_notice: typing.Optional[str] = ''
|
||||
"""只要设置了就会输出到控制台"""
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ import os
|
|||
import traceback
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from mirai.models.message import MessageComponent, Plain, MessageChain
|
||||
|
||||
from ...core import app
|
||||
from . import strategy
|
||||
|
@ -11,6 +10,7 @@ from .strategies import image, forward
|
|||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
|
||||
@stage.stage_class("LongTextProcessStage")
|
||||
|
@ -63,14 +63,14 @@ class LongTextProcessStage(stage.PipelineStage):
|
|||
contains_non_plain = False
|
||||
|
||||
for msg in query.resp_message_chain[-1]:
|
||||
if not isinstance(msg, Plain):
|
||||
if not isinstance(msg, platform_message.Plain):
|
||||
contains_non_plain = True
|
||||
break
|
||||
|
||||
if contains_non_plain:
|
||||
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
|
||||
elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
|
||||
query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
|
||||
query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
|
|
|
@ -2,15 +2,14 @@
|
|||
from __future__ import annotations
|
||||
import typing
|
||||
|
||||
from mirai.models import MessageChain
|
||||
from mirai.models.message import MessageComponent, ForwardMessageNode
|
||||
from mirai.models.base import MiraiBaseModel
|
||||
import pydantic
|
||||
|
||||
from .. import strategy as strategy_model
|
||||
from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
|
||||
class ForwardMessageDiaplay(MiraiBaseModel):
|
||||
class ForwardMessageDiaplay(pydantic.BaseModel):
|
||||
title: str = "群聊的聊天记录"
|
||||
brief: str = "[聊天记录]"
|
||||
source: str = "聊天记录"
|
||||
|
@ -18,13 +17,13 @@ class ForwardMessageDiaplay(MiraiBaseModel):
|
|||
summary: str = "查看x条转发消息"
|
||||
|
||||
|
||||
class Forward(MessageComponent):
|
||||
class Forward(platform_message.MessageComponent):
|
||||
"""合并转发。"""
|
||||
type: str = "Forward"
|
||||
"""消息组件类型。"""
|
||||
display: ForwardMessageDiaplay
|
||||
"""显示信息"""
|
||||
node_list: typing.List[ForwardMessageNode]
|
||||
node_list: typing.List[platform_message.ForwardMessageNode]
|
||||
"""转发消息节点列表。"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
if len(args) == 1:
|
||||
|
@ -39,7 +38,7 @@ class Forward(MessageComponent):
|
|||
@strategy_model.strategy_class("forward")
|
||||
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
display = ForwardMessageDiaplay(
|
||||
title="群聊的聊天记录",
|
||||
brief="[聊天记录]",
|
||||
|
@ -49,10 +48,10 @@ class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
|||
)
|
||||
|
||||
node_list = [
|
||||
ForwardMessageNode(
|
||||
platform_message.ForwardMessageNode(
|
||||
sender_id=query.adapter.bot_account_id,
|
||||
sender_name='QQ用户',
|
||||
message_chain=MessageChain([message])
|
||||
message_chain=platform_message.MessageChain([message])
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
@ -8,8 +8,7 @@ import re
|
|||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from mirai.models import MessageChain, Image as ImageComponent
|
||||
from mirai.models.message import MessageComponent
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
from .. import strategy as strategy_model
|
||||
from ....core import entities as core_entities
|
||||
|
@ -23,7 +22,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
|||
async def initialize(self):
|
||||
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")
|
||||
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
img_path = self.text_to_image(
|
||||
text_str=message,
|
||||
save_as='temp/{}.png'.format(int(time.time()))
|
||||
|
@ -46,7 +45,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
|||
os.remove(compressed_path)
|
||||
|
||||
return [
|
||||
ImageComponent(
|
||||
platform_message.Image(
|
||||
base64=b64.decode('utf-8'),
|
||||
)
|
||||
]
|
||||
|
|
|
@ -2,11 +2,10 @@ from __future__ import annotations
|
|||
import abc
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
from mirai.models.message import MessageComponent
|
||||
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
|
||||
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
|
||||
|
@ -51,7 +50,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
|
|||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
"""处理长文本
|
||||
|
||||
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
|
||||
|
@ -61,6 +60,6 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
|
|||
query (core_entities.Query): 此次请求的上下文对象
|
||||
|
||||
Returns:
|
||||
list[mirai.models.messages.MessageComponent]: 转换后的 YiriMirai 消息组件列表
|
||||
list[platform_message.MessageComponent]: 转换后的 平台 消息组件列表
|
||||
"""
|
||||
return []
|
||||
|
|
|
@ -2,10 +2,11 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
|
||||
import mirai
|
||||
|
||||
from ..core import entities
|
||||
from ..platform import adapter as msadapter
|
||||
from ..platform.types import message as platform_message
|
||||
from ..platform.types import events as platform_events
|
||||
|
||||
|
||||
class QueryPool:
|
||||
|
@ -30,8 +31,8 @@ class QueryPool:
|
|||
launcher_type: entities.LauncherTypes,
|
||||
launcher_id: int,
|
||||
sender_id: int,
|
||||
message_event: mirai.MessageEvent,
|
||||
message_chain: mirai.MessageChain,
|
||||
message_event: platform_events.MessageEvent,
|
||||
message_chain: platform_message.MessageChain,
|
||||
adapter: msadapter.MessageSourceAdapter
|
||||
) -> entities.Query:
|
||||
async with self.condition:
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...provider import entities as llm_entities
|
||||
from ...plugin import events
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
|
||||
@stage.stage_class("PreProcessor")
|
||||
|
@ -55,11 +55,11 @@ class PreProcessor(stage.PipelineStage):
|
|||
content_list = []
|
||||
|
||||
for me in query.message_chain:
|
||||
if isinstance(me, mirai.Plain):
|
||||
if isinstance(me, platform_message.Plain):
|
||||
content_list.append(
|
||||
llm_entities.ContentElement.from_text(me.text)
|
||||
)
|
||||
elif isinstance(me, mirai.Image):
|
||||
elif isinstance(me, platform_message.Image):
|
||||
if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported:
|
||||
if me.url is not None:
|
||||
content_list.append(
|
||||
|
|
|
@ -5,7 +5,6 @@ import time
|
|||
import traceback
|
||||
import json
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
|
@ -13,6 +12,8 @@ from ....core import entities as core_entities
|
|||
from ....provider import entities as llm_entities, runnermgr
|
||||
from ....plugin import events
|
||||
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
|
||||
class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
|
@ -40,7 +41,7 @@ class ChatMessageHandler(handler.MessageHandler):
|
|||
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply is not None:
|
||||
mc = mirai.MessageChain(event_ctx.event.reply)
|
||||
mc = platform_message.MessageChain(event_ctx.event.reply)
|
||||
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
from __future__ import annotations
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....provider import entities as llm_entities
|
||||
from ....plugin import events
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
|
||||
class CommandHandler(handler.MessageHandler):
|
||||
|
@ -46,7 +46,7 @@ class CommandHandler(handler.MessageHandler):
|
|||
if event_ctx.is_prevented_default():
|
||||
|
||||
if event_ctx.event.reply is not None:
|
||||
mc = mirai.MessageChain(event_ctx.event.reply)
|
||||
mc = platform_message.MessageChain(event_ctx.event.reply)
|
||||
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
|
@ -63,8 +63,8 @@ class CommandHandler(handler.MessageHandler):
|
|||
else:
|
||||
|
||||
if event_ctx.event.alter is not None:
|
||||
query.message_chain = mirai.MessageChain([
|
||||
mirai.Plain(event_ctx.event.alter)
|
||||
query.message_chain = platform_message.MessageChain([
|
||||
platform_message.Plain(event_ctx.event.alter)
|
||||
])
|
||||
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
|
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||
import random
|
||||
import asyncio
|
||||
|
||||
import mirai
|
||||
|
||||
from ...core import app
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import pydantic
|
||||
import mirai
|
||||
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
|
||||
class RuleJudgeResult(pydantic.BaseModel):
|
||||
|
||||
matching: bool = False
|
||||
|
||||
replacement: mirai.MessageChain = None
|
||||
replacement: platform_message.MessageChain = None
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import mirai
|
||||
|
||||
from ...core import app
|
||||
from . import entities as rule_entities, rule
|
||||
|
|
|
@ -2,11 +2,11 @@ from __future__ import annotations
|
|||
import abc
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
|
||||
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
|
||||
|
||||
|
@ -35,7 +35,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
|
|||
async def match(
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
) -> entities.RuleJudgeResult:
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
|
||||
@rule_model.rule_class("at-bot")
|
||||
|
@ -13,16 +13,16 @@ class AtBotRule(rule_model.GroupRespondRule):
|
|||
async def match(
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
) -> entities.RuleJudgeResult:
|
||||
|
||||
if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
||||
message_chain.remove(mirai.At(query.adapter.bot_account_id))
|
||||
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
||||
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||
|
||||
if message_chain.has(mirai.At(query.adapter.bot_account_id)): # 回复消息时会at两次,检查并删除重复的
|
||||
message_chain.remove(mirai.At(query.adapter.bot_account_id))
|
||||
if message_chain.has(platform_message.At(query.adapter.bot_account_id)): # 回复消息时会at两次,检查并删除重复的
|
||||
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=True,
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import mirai
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
|
||||
@rule_model.rule_class("prefix")
|
||||
|
@ -11,7 +11,7 @@ class PrefixRule(rule_model.GroupRespondRule):
|
|||
async def match(
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
) -> entities.RuleJudgeResult:
|
||||
|
@ -22,7 +22,7 @@ class PrefixRule(rule_model.GroupRespondRule):
|
|||
|
||||
# 查找第一个plain元素
|
||||
for me in message_chain:
|
||||
if isinstance(me, mirai.Plain):
|
||||
if isinstance(me, platform_message.Plain):
|
||||
me.text = me.text[len(prefix):]
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import random
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
|
||||
@rule_model.rule_class("random")
|
||||
|
@ -13,7 +13,7 @@ class RandomRespRule(rule_model.GroupRespondRule):
|
|||
async def match(
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
) -> entities.RuleJudgeResult:
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import re
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
|
||||
@rule_model.rule_class("regexp")
|
||||
|
@ -13,7 +13,7 @@ class RegExpRule(rule_model.GroupRespondRule):
|
|||
async def match(
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
) -> entities.RuleJudgeResult:
|
||||
|
|
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
|||
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from .. import entities
|
||||
|
@ -10,6 +9,7 @@ from .. import stage, entities, stagemgr
|
|||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from ...plugin import events
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
|
||||
@stage.stage_class("ResponseWrapper")
|
||||
|
@ -34,7 +34,7 @@ class ResponseWrapper(stage.PipelineStage):
|
|||
"""
|
||||
|
||||
# 如果 resp_messages[-1] 已经是 MessageChain 了
|
||||
if isinstance(query.resp_messages[-1], mirai.MessageChain):
|
||||
if isinstance(query.resp_messages[-1], platform_message.MessageChain):
|
||||
query.resp_message_chain.append(query.resp_messages[-1])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
|
@ -45,19 +45,14 @@ class ResponseWrapper(stage.PipelineStage):
|
|||
else:
|
||||
|
||||
if query.resp_messages[-1].role == 'command':
|
||||
# query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
|
||||
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))
|
||||
query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain(prefix_text='[bot] '))
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
elif query.resp_messages[-1].role == 'plugin':
|
||||
# if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
|
||||
# query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
|
||||
# else:
|
||||
# query.resp_message_chain.append(query.resp_messages[-1].content)
|
||||
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())
|
||||
query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain())
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
|
@ -72,7 +67,7 @@ class ResponseWrapper(stage.PipelineStage):
|
|||
reply_text = ''
|
||||
|
||||
if result.content: # 有内容
|
||||
reply_text = str(result.get_content_mirai_message_chain())
|
||||
reply_text = str(result.get_content_platform_message_chain())
|
||||
|
||||
# ============= 触发插件事件 ===============
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
|
@ -96,11 +91,11 @@ class ResponseWrapper(stage.PipelineStage):
|
|||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
|
||||
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
|
||||
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain.append(result.get_content_mirai_message_chain())
|
||||
query.resp_message_chain.append(result.get_content_platform_message_chain())
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
|
@ -113,7 +108,7 @@ class ResponseWrapper(stage.PipelineStage):
|
|||
|
||||
reply_text = f'调用函数 {".".join(function_names)}...'
|
||||
|
||||
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
|
||||
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
|
||||
|
||||
if self.ap.platform_cfg.data['track-function-calls']:
|
||||
|
||||
|
@ -139,11 +134,11 @@ class ResponseWrapper(stage.PipelineStage):
|
|||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
|
||||
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
|
||||
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
|
||||
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
|
|
|
@ -4,9 +4,10 @@ from __future__ import annotations
|
|||
import typing
|
||||
import abc
|
||||
|
||||
import mirai
|
||||
|
||||
from ..core import app
|
||||
from .types import message as platform_message
|
||||
from .types import events as platform_events
|
||||
|
||||
|
||||
preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = []
|
||||
|
@ -55,28 +56,28 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
|
|||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: mirai.MessageChain
|
||||
message: platform_message.MessageChain
|
||||
):
|
||||
"""主动发送消息
|
||||
|
||||
Args:
|
||||
target_type (str): 目标类型,`person`或`group`
|
||||
target_id (str): 目标ID
|
||||
message (mirai.MessageChain): YiriMirai库的消息链
|
||||
message (platform.types.MessageChain): 消息链
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: mirai.MessageEvent,
|
||||
message: mirai.MessageChain,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False
|
||||
):
|
||||
"""回复消息
|
||||
|
||||
Args:
|
||||
message_source (mirai.MessageEvent): YiriMirai消息源事件
|
||||
message (mirai.MessageChain): YiriMirai库的消息链
|
||||
message_source (platform.types.MessageEvent): 消息源事件
|
||||
message (platform.types.MessageChain): 消息链
|
||||
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
@ -87,27 +88,27 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
|
|||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None]
|
||||
event_type: typing.Type[platform_message.Event],
|
||||
callback: typing.Callable[[platform_message.Event, MessageSourceAdapter], None]
|
||||
):
|
||||
"""注册事件监听器
|
||||
|
||||
Args:
|
||||
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
|
||||
callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件
|
||||
event_type (typing.Type[platform.types.Event]): 事件类型
|
||||
callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None]
|
||||
event_type: typing.Type[platform_message.Event],
|
||||
callback: typing.Callable[[platform_message.Event, MessageSourceAdapter], None]
|
||||
):
|
||||
"""注销事件监听器
|
||||
|
||||
Args:
|
||||
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
|
||||
callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件
|
||||
event_type (typing.Type[platform.types.Event]): 事件类型
|
||||
callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -127,26 +128,26 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
|
|||
class MessageConverter:
|
||||
"""消息链转换器基类"""
|
||||
@staticmethod
|
||||
def yiri2target(message_chain: mirai.MessageChain):
|
||||
"""将YiriMirai消息链转换为目标消息链
|
||||
def yiri2target(message_chain: platform_message.MessageChain):
|
||||
"""将源平台消息链转换为目标平台消息链
|
||||
|
||||
Args:
|
||||
message_chain (mirai.MessageChain): YiriMirai消息链
|
||||
message_chain (platform.types.MessageChain): 源平台消息链
|
||||
|
||||
Returns:
|
||||
typing.Any: 目标消息链
|
||||
typing.Any: 目标平台消息链
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(message_chain: typing.Any) -> mirai.MessageChain:
|
||||
"""将目标消息链转换为YiriMirai消息链
|
||||
def target2yiri(message_chain: typing.Any) -> platform_message.MessageChain:
|
||||
"""将目标平台消息链转换为源平台消息链
|
||||
|
||||
Args:
|
||||
message_chain (typing.Any): 目标消息链
|
||||
message_chain (typing.Any): 目标平台消息链
|
||||
|
||||
Returns:
|
||||
mirai.MessageChain: YiriMirai消息链
|
||||
platform.types.MessageChain: 源平台消息链
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -155,25 +156,25 @@ class EventConverter:
|
|||
"""事件转换器基类"""
|
||||
|
||||
@staticmethod
|
||||
def yiri2target(event: typing.Type[mirai.Event]):
|
||||
"""将YiriMirai事件转换为目标事件
|
||||
def yiri2target(event: typing.Type[platform_message.Event]):
|
||||
"""将源平台事件转换为目标平台事件
|
||||
|
||||
Args:
|
||||
event (typing.Type[mirai.Event]): YiriMirai事件
|
||||
event (typing.Type[platform.types.Event]): 源平台事件
|
||||
|
||||
Returns:
|
||||
typing.Any: 目标事件
|
||||
typing.Any: 目标平台事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(event: typing.Any) -> mirai.Event:
|
||||
"""将目标事件的调用参数转换为YiriMirai的事件参数对象
|
||||
def target2yiri(event: typing.Any) -> platform_message.Event:
|
||||
"""将目标平台事件的调用参数转换为源平台的事件参数对象
|
||||
|
||||
Args:
|
||||
event (typing.Any): 目标事件
|
||||
event (typing.Any): 目标平台事件
|
||||
|
||||
Returns:
|
||||
typing.Type[mirai.Event]: YiriMirai事件
|
||||
typing.Type[platform.types.Event]: 源平台事件
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -2,17 +2,24 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \
|
||||
FriendMessage, Image, MessageChain, Plain
|
||||
import mirai
|
||||
# FriendMessage, Image, MessageChain, Plain
|
||||
from ..platform import adapter as msadapter
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from ..plugin import events
|
||||
from .types import message as platform_message
|
||||
from .types import events as platform_events
|
||||
from .types import entities as platform_entities
|
||||
|
||||
# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题
|
||||
from . import types as mirai
|
||||
sys.modules['mirai'] = mirai
|
||||
|
||||
|
||||
# 控制QQ消息输入输出的类
|
||||
class PlatformManager:
|
||||
|
@ -32,7 +39,7 @@ class PlatformManager:
|
|||
|
||||
from .sources import yirimirai, nakuru, aiocqhttp, qqbotpy
|
||||
|
||||
async def on_friend_message(event: FriendMessage, adapter: msadapter.MessageSourceAdapter):
|
||||
async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessageSourceAdapter):
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.PersonMessageReceived(
|
||||
|
@ -55,7 +62,7 @@ class PlatformManager:
|
|||
adapter=adapter
|
||||
)
|
||||
|
||||
async def on_stranger_message(event: StrangerMessage, adapter: msadapter.MessageSourceAdapter):
|
||||
async def on_stranger_message(event: platform_events.StrangerMessage, adapter: msadapter.MessageSourceAdapter):
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.PersonMessageReceived(
|
||||
|
@ -78,7 +85,7 @@ class PlatformManager:
|
|||
adapter=adapter
|
||||
)
|
||||
|
||||
async def on_group_message(event: GroupMessage, adapter: msadapter.MessageSourceAdapter):
|
||||
async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessageSourceAdapter):
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.GroupMessageReceived(
|
||||
|
@ -127,16 +134,16 @@ class PlatformManager:
|
|||
|
||||
if adapter_name == 'yiri-mirai':
|
||||
adapter_inst.register_listener(
|
||||
StrangerMessage,
|
||||
platform_events.StrangerMessage,
|
||||
on_stranger_message
|
||||
)
|
||||
|
||||
adapter_inst.register_listener(
|
||||
FriendMessage,
|
||||
platform_events.FriendMessage,
|
||||
on_friend_message
|
||||
)
|
||||
adapter_inst.register_listener(
|
||||
GroupMessage,
|
||||
platform_events.GroupMessage,
|
||||
on_group_message
|
||||
)
|
||||
|
||||
|
@ -146,13 +153,13 @@ class PlatformManager:
|
|||
if len(self.adapters) == 0:
|
||||
self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。')
|
||||
|
||||
async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter):
|
||||
async def send(self, event: platform_events.MessageEvent, msg: platform_message.MessageChain, adapter: msadapter.MessageSourceAdapter):
|
||||
|
||||
if self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage):
|
||||
if self.ap.platform_cfg.data['at-sender'] and isinstance(event, platform_events.GroupMessage):
|
||||
|
||||
msg.insert(
|
||||
0,
|
||||
At(
|
||||
platform_message.At(
|
||||
event.sender.id
|
||||
)
|
||||
)
|
||||
|
|
|
@ -5,31 +5,32 @@ import traceback
|
|||
import time
|
||||
import datetime
|
||||
|
||||
import mirai
|
||||
import mirai.models.message as yiri_message
|
||||
import aiocqhttp
|
||||
|
||||
from .. import adapter
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...core import app
|
||||
from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
|
||||
|
||||
class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
|
||||
@staticmethod
|
||||
def yiri2target(message_chain: mirai.MessageChain) -> typing.Tuple[list, int, datetime.datetime]:
|
||||
def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]:
|
||||
msg_list = aiocqhttp.Message()
|
||||
|
||||
msg_id = 0
|
||||
msg_time = None
|
||||
|
||||
for msg in message_chain:
|
||||
if type(msg) is mirai.Plain:
|
||||
if type(msg) is platform_message.Plain:
|
||||
msg_list.append(aiocqhttp.MessageSegment.text(msg.text))
|
||||
elif type(msg) is yiri_message.Source:
|
||||
elif type(msg) is platform_message.Source:
|
||||
msg_id = msg.id
|
||||
msg_time = msg.time
|
||||
elif type(msg) is mirai.Image:
|
||||
elif type(msg) is platform_message.Image:
|
||||
arg = ''
|
||||
if msg.base64:
|
||||
arg = msg.base64
|
||||
|
@ -40,13 +41,11 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
|||
elif msg.path:
|
||||
arg = msg.path
|
||||
msg_list.append(aiocqhttp.MessageSegment.image(arg))
|
||||
elif type(msg) is mirai.At:
|
||||
elif type(msg) is platform_message.At:
|
||||
msg_list.append(aiocqhttp.MessageSegment.at(msg.target))
|
||||
elif type(msg) is mirai.AtAll:
|
||||
elif type(msg) is platform_message.AtAll:
|
||||
msg_list.append(aiocqhttp.MessageSegment.at("all"))
|
||||
elif type(msg) is mirai.Face:
|
||||
msg_list.append(aiocqhttp.MessageSegment.face(msg.face_id))
|
||||
elif type(msg) is mirai.Voice:
|
||||
elif type(msg) is platform_message.Voice:
|
||||
arg = ''
|
||||
if msg.base64:
|
||||
arg = msg.base64
|
||||
|
@ -74,25 +73,25 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
|||
yiri_msg_list = []
|
||||
|
||||
yiri_msg_list.append(
|
||||
yiri_message.Source(id=message_id, time=datetime.datetime.now())
|
||||
platform_message.Source(id=message_id, time=datetime.datetime.now())
|
||||
)
|
||||
|
||||
for msg in message:
|
||||
if msg.type == "at":
|
||||
if msg.data["qq"] == "all":
|
||||
yiri_msg_list.append(yiri_message.AtAll())
|
||||
yiri_msg_list.append(platform_message.AtAll())
|
||||
else:
|
||||
yiri_msg_list.append(
|
||||
yiri_message.At(
|
||||
platform_message.At(
|
||||
target=msg.data["qq"],
|
||||
)
|
||||
)
|
||||
elif msg.type == "text":
|
||||
yiri_msg_list.append(yiri_message.Plain(text=msg.data["text"]))
|
||||
yiri_msg_list.append(platform_message.Plain(text=msg.data["text"]))
|
||||
elif msg.type == "image":
|
||||
yiri_msg_list.append(yiri_message.Image(url=msg.data["url"]))
|
||||
yiri_msg_list.append(platform_message.Image(url=msg.data["url"]))
|
||||
|
||||
chain = mirai.MessageChain(yiri_msg_list)
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
|
||||
return chain
|
||||
|
||||
|
@ -100,11 +99,11 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
|||
class AiocqhttpEventConverter(adapter.EventConverter):
|
||||
|
||||
@staticmethod
|
||||
def yiri2target(event: mirai.Event, bot_account_id: int):
|
||||
def yiri2target(event: platform_events.Event, bot_account_id: int):
|
||||
|
||||
msg, msg_id, msg_time = AiocqhttpMessageConverter.yiri2target(event.message_chain)
|
||||
|
||||
if type(event) is mirai.GroupMessage:
|
||||
if type(event) is platform_events.GroupMessage:
|
||||
role = "member"
|
||||
|
||||
if event.sender.permission == "ADMINISTRATOR":
|
||||
|
@ -140,7 +139,7 @@ class AiocqhttpEventConverter(adapter.EventConverter):
|
|||
}
|
||||
|
||||
return aiocqhttp.Event.from_payload(payload)
|
||||
elif type(event) is mirai.FriendMessage:
|
||||
elif type(event) is platform_events.FriendMessage:
|
||||
|
||||
payload = {
|
||||
"post_type": "message",
|
||||
|
@ -178,15 +177,15 @@ class AiocqhttpEventConverter(adapter.EventConverter):
|
|||
permission = "ADMINISTRATOR"
|
||||
elif event.sender["role"] == "owner":
|
||||
permission = "OWNER"
|
||||
converted_event = mirai.GroupMessage(
|
||||
sender=mirai.models.entities.GroupMember(
|
||||
converted_event = platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
id=event.sender["user_id"], # message_seq 放哪?
|
||||
member_name=event.sender["nickname"],
|
||||
permission=permission,
|
||||
group=mirai.models.entities.Group(
|
||||
group=platform_entities.Group(
|
||||
id=event.group_id,
|
||||
name=event.sender["nickname"],
|
||||
permission=mirai.models.entities.Permission.Member,
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title=event.sender["title"] if "title" in event.sender else "",
|
||||
join_timestamp=0,
|
||||
|
@ -198,8 +197,8 @@ class AiocqhttpEventConverter(adapter.EventConverter):
|
|||
)
|
||||
return converted_event
|
||||
elif event.message_type == "private":
|
||||
return mirai.FriendMessage(
|
||||
sender=mirai.models.entities.Friend(
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=event.sender["user_id"],
|
||||
nickname=event.sender["nickname"],
|
||||
remark="",
|
||||
|
@ -241,7 +240,7 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
|
|||
self.bot = aiocqhttp.CQHttp()
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: mirai.MessageChain
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
|
||||
|
||||
|
@ -252,8 +251,8 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
|
|||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: mirai.MessageEvent,
|
||||
message: mirai.MessageChain,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
|
||||
|
@ -271,8 +270,8 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
|
|||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None],
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessageSourceAdapter], None],
|
||||
):
|
||||
async def on_message(event: aiocqhttp.Event):
|
||||
self.bot_account_id = event.self_id
|
||||
|
@ -281,15 +280,15 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
|
|||
except:
|
||||
traceback.print_exc()
|
||||
|
||||
if event_type == mirai.GroupMessage:
|
||||
if event_type == platform_events.GroupMessage:
|
||||
self.bot.on_message("group")(on_message)
|
||||
elif event_type == mirai.FriendMessage:
|
||||
elif event_type == platform_events.FriendMessage:
|
||||
self.bot.on_message("private")(on_message)
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None],
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessageSourceAdapter], None],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
|
|
|
@ -6,26 +6,28 @@ import typing
|
|||
import traceback
|
||||
import logging
|
||||
|
||||
import mirai
|
||||
|
||||
import nakuru
|
||||
import nakuru.entities.components as nkc
|
||||
|
||||
from .. import adapter as adapter_model
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...platform.types import message as platform_message
|
||||
from ...platform.types import entities as platform_entities
|
||||
from ...platform.types import events as platform_events
|
||||
|
||||
|
||||
class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
"""消息转换器"""
|
||||
@staticmethod
|
||||
def yiri2target(message_chain: mirai.MessageChain) -> list:
|
||||
def yiri2target(message_chain: platform_message.MessageChain) -> list:
|
||||
msg_list = []
|
||||
if type(message_chain) is mirai.MessageChain:
|
||||
if type(message_chain) is platform_message.MessageChain:
|
||||
msg_list = message_chain.__root__
|
||||
elif type(message_chain) is list:
|
||||
msg_list = message_chain
|
||||
elif type(message_chain) is str:
|
||||
msg_list = [mirai.Plain(message_chain)]
|
||||
msg_list = [platform_message.Plain(message_chain)]
|
||||
else:
|
||||
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
|
||||
|
||||
|
@ -33,22 +35,20 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
|||
|
||||
# 遍历并转换
|
||||
for component in msg_list:
|
||||
if type(component) is mirai.Plain:
|
||||
if type(component) is platform_message.Plain:
|
||||
nakuru_msg_list.append(nkc.Plain(component.text, False))
|
||||
elif type(component) is mirai.Image:
|
||||
elif type(component) is platform_message.Image:
|
||||
if component.url is not None:
|
||||
nakuru_msg_list.append(nkc.Image.fromURL(component.url))
|
||||
elif component.base64 is not None:
|
||||
nakuru_msg_list.append(nkc.Image.fromBase64(component.base64))
|
||||
elif component.path is not None:
|
||||
nakuru_msg_list.append(nkc.Image.fromFileSystem(component.path))
|
||||
elif type(component) is mirai.Face:
|
||||
nakuru_msg_list.append(nkc.Face(id=component.face_id))
|
||||
elif type(component) is mirai.At:
|
||||
elif type(component) is platform_message.At:
|
||||
nakuru_msg_list.append(nkc.At(qq=component.target))
|
||||
elif type(component) is mirai.AtAll:
|
||||
elif type(component) is platform_message.AtAll:
|
||||
nakuru_msg_list.append(nkc.AtAll())
|
||||
elif type(component) is mirai.Voice:
|
||||
elif type(component) is platform_message.Voice:
|
||||
if component.url is not None:
|
||||
nakuru_msg_list.append(nkc.Record.fromURL(component.url))
|
||||
elif component.path is not None:
|
||||
|
@ -80,49 +80,47 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
|||
return nakuru_msg_list
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(message_chain: typing.Any, message_id: int = -1) -> mirai.MessageChain:
|
||||
def target2yiri(message_chain: typing.Any, message_id: int = -1) -> platform_message.MessageChain:
|
||||
"""将Yiri的消息链转换为YiriMirai的消息链"""
|
||||
assert type(message_chain) is list
|
||||
|
||||
yiri_msg_list = []
|
||||
import datetime
|
||||
# 添加Source组件以标记message_id等信息
|
||||
yiri_msg_list.append(mirai.models.message.Source(id=message_id, time=datetime.datetime.now()))
|
||||
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
|
||||
for component in message_chain:
|
||||
if type(component) is nkc.Plain:
|
||||
yiri_msg_list.append(mirai.Plain(text=component.text))
|
||||
yiri_msg_list.append(platform_message.Plain(text=component.text))
|
||||
elif type(component) is nkc.Image:
|
||||
yiri_msg_list.append(mirai.Image(url=component.url))
|
||||
elif type(component) is nkc.Face:
|
||||
yiri_msg_list.append(mirai.Face(face_id=component.id))
|
||||
yiri_msg_list.append(platform_message.Image(url=component.url))
|
||||
elif type(component) is nkc.At:
|
||||
yiri_msg_list.append(mirai.At(target=component.qq))
|
||||
yiri_msg_list.append(platform_message.At(target=component.qq))
|
||||
elif type(component) is nkc.AtAll:
|
||||
yiri_msg_list.append(mirai.AtAll())
|
||||
yiri_msg_list.append(platform_message.AtAll())
|
||||
else:
|
||||
pass
|
||||
# logging.debug("转换后的消息链: " + str(yiri_msg_list))
|
||||
chain = mirai.MessageChain(yiri_msg_list)
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
return chain
|
||||
|
||||
|
||||
class NakuruProjectEventConverter(adapter_model.EventConverter):
|
||||
"""事件转换器"""
|
||||
@staticmethod
|
||||
def yiri2target(event: typing.Type[mirai.Event]):
|
||||
if event is mirai.GroupMessage:
|
||||
def yiri2target(event: typing.Type[platform_events.Event]):
|
||||
if event is platform_events.GroupMessage:
|
||||
return nakuru.GroupMessage
|
||||
elif event is mirai.FriendMessage:
|
||||
elif event is platform_events.FriendMessage:
|
||||
return nakuru.FriendMessage
|
||||
else:
|
||||
raise Exception("未支持转换的事件类型: " + str(event))
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(event: typing.Any) -> mirai.Event:
|
||||
def target2yiri(event: typing.Any) -> platform_events.Event:
|
||||
yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id)
|
||||
if type(event) is nakuru.FriendMessage: # 私聊消息事件
|
||||
return mirai.FriendMessage(
|
||||
sender=mirai.models.entities.Friend(
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=event.sender.user_id,
|
||||
nickname=event.sender.nickname,
|
||||
remark=event.sender.nickname
|
||||
|
@ -138,16 +136,15 @@ class NakuruProjectEventConverter(adapter_model.EventConverter):
|
|||
elif event.sender.role == "owner":
|
||||
permission = "OWNER"
|
||||
|
||||
import mirai.models.entities as entities
|
||||
return mirai.GroupMessage(
|
||||
sender=mirai.models.entities.GroupMember(
|
||||
return platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
id=event.sender.user_id,
|
||||
member_name=event.sender.nickname,
|
||||
permission=permission,
|
||||
group=mirai.models.entities.Group(
|
||||
group=platform_entities.Group(
|
||||
id=event.group_id,
|
||||
name=event.sender.nickname,
|
||||
permission=entities.Permission.Member
|
||||
permission=platform_entities.Permission.Member
|
||||
),
|
||||
special_title=event.sender.title,
|
||||
join_timestamp=0,
|
||||
|
@ -189,7 +186,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
|
|||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: typing.Union[mirai.MessageChain, list],
|
||||
message: typing.Union[platform_message.MessageChain, list],
|
||||
converted: bool = False
|
||||
):
|
||||
task = None
|
||||
|
@ -222,8 +219,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
|
|||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: mirai.MessageEvent,
|
||||
message: mirai.MessageChain,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False
|
||||
):
|
||||
message = self.message_converter.yiri2target(message)
|
||||
|
@ -233,14 +230,14 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
|
|||
id=message_source.message_chain.message_id,
|
||||
)
|
||||
)
|
||||
if type(message_source) is mirai.GroupMessage:
|
||||
if type(message_source) is platform_events.GroupMessage:
|
||||
await self.send_message(
|
||||
"group",
|
||||
message_source.sender.group.id,
|
||||
message,
|
||||
converted=True
|
||||
)
|
||||
elif type(message_source) is mirai.FriendMessage:
|
||||
elif type(message_source) is platform_events.FriendMessage:
|
||||
await self.send_message(
|
||||
"person",
|
||||
message_source.sender.id,
|
||||
|
@ -258,8 +255,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
|
|||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessageSourceAdapter], None]
|
||||
):
|
||||
try:
|
||||
|
||||
|
@ -286,8 +283,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
|
|||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessageSourceAdapter], None]
|
||||
):
|
||||
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ import datetime
|
|||
import re
|
||||
import traceback
|
||||
|
||||
import mirai
|
||||
import botpy
|
||||
import botpy.message as botpy_message
|
||||
import botpy.types.message as botpy_message_type
|
||||
|
@ -17,17 +16,21 @@ from .. import adapter as adapter_model
|
|||
from ...pipeline.longtext.strategies import forward
|
||||
from ...core import app
|
||||
from ...config import manager as cfg_mgr
|
||||
from ...platform.types import entities as platform_entities
|
||||
from ...platform.types import events as platform_events
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
|
||||
class OfficialGroupMessage(mirai.GroupMessage):
|
||||
|
||||
class OfficialGroupMessage(platform_events.GroupMessage):
|
||||
pass
|
||||
|
||||
class OfficialFriendMessage(mirai.FriendMessage):
|
||||
class OfficialFriendMessage(platform_events.FriendMessage):
|
||||
pass
|
||||
|
||||
event_handler_mapping = {
|
||||
mirai.GroupMessage: ["on_at_message_create", "on_group_at_message_create"],
|
||||
mirai.FriendMessage: ["on_direct_message_create", "on_c2c_message_create"],
|
||||
platform_events.GroupMessage: ["on_at_message_create", "on_group_at_message_create"],
|
||||
platform_events.FriendMessage: ["on_direct_message_create", "on_c2c_message_create"],
|
||||
}
|
||||
|
||||
|
||||
|
@ -123,16 +126,16 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
|||
"""QQ 官方消息转换器"""
|
||||
|
||||
@staticmethod
|
||||
def yiri2target(message_chain: mirai.MessageChain):
|
||||
def yiri2target(message_chain: platform_message.MessageChain):
|
||||
"""将 YiriMirai 的消息链转换为 QQ 官方消息"""
|
||||
|
||||
msg_list = []
|
||||
if type(message_chain) is mirai.MessageChain:
|
||||
if type(message_chain) is platform_message.MessageChain:
|
||||
msg_list = message_chain.__root__
|
||||
elif type(message_chain) is list:
|
||||
msg_list = message_chain
|
||||
elif type(message_chain) is str:
|
||||
msg_list = [mirai.Plain(text=message_chain)]
|
||||
msg_list = [platform_message.Plain(text=message_chain)]
|
||||
else:
|
||||
raise Exception(
|
||||
"Unknown message type: " + str(message_chain) + str(type(message_chain))
|
||||
|
@ -153,22 +156,22 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
|||
|
||||
# 遍历并转换
|
||||
for component in msg_list:
|
||||
if type(component) is mirai.Plain:
|
||||
if type(component) is platform_message.Plain:
|
||||
offcial_messages.append({"type": "text", "content": component.text})
|
||||
elif type(component) is mirai.Image:
|
||||
elif type(component) is platform_message.Image:
|
||||
if component.url is not None:
|
||||
offcial_messages.append({"type": "image", "content": component.url})
|
||||
elif component.path is not None:
|
||||
offcial_messages.append(
|
||||
{"type": "file_image", "content": component.path}
|
||||
)
|
||||
elif type(component) is mirai.At:
|
||||
elif type(component) is platform_message.At:
|
||||
offcial_messages.append({"type": "at", "content": ""})
|
||||
elif type(component) is mirai.AtAll:
|
||||
elif type(component) is platform_message.AtAll:
|
||||
print(
|
||||
"上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
|
||||
)
|
||||
elif type(component) is mirai.Voice:
|
||||
elif type(component) is platform_message.Voice:
|
||||
print(
|
||||
"上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
|
||||
)
|
||||
|
@ -197,29 +200,29 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
|||
message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage],
|
||||
message_id: str = None,
|
||||
bot_account_id: int = 0,
|
||||
) -> mirai.MessageChain:
|
||||
) -> platform_message.MessageChain:
|
||||
yiri_msg_list = []
|
||||
# 存id
|
||||
|
||||
yiri_msg_list.append(
|
||||
mirai.models.message.Source(
|
||||
platform_message.Source(
|
||||
id=save_msg_id(message_id), time=datetime.datetime.now()
|
||||
)
|
||||
)
|
||||
|
||||
if type(message) not in [botpy_message.DirectMessage, botpy_message.C2CMessage]:
|
||||
yiri_msg_list.append(mirai.At(target=bot_account_id))
|
||||
yiri_msg_list.append(platform_message.At(target=bot_account_id))
|
||||
|
||||
if hasattr(message, "mentions"):
|
||||
for mention in message.mentions:
|
||||
if mention.bot:
|
||||
continue
|
||||
|
||||
yiri_msg_list.append(mirai.At(target=mention.id))
|
||||
yiri_msg_list.append(platform_message.At(target=mention.id))
|
||||
|
||||
for attachment in message.attachments:
|
||||
if attachment.content_type.startswith("image"):
|
||||
yiri_msg_list.append(mirai.Image(url=attachment.url))
|
||||
yiri_msg_list.append(platform_message.Image(url=attachment.url))
|
||||
else:
|
||||
logging.warning(
|
||||
"不支持的附件类型:" + attachment.content_type + ",忽略此附件。"
|
||||
|
@ -227,9 +230,9 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
|||
|
||||
content = re.sub(r"<@!\d+>", "", str(message.content))
|
||||
if content.strip() != "":
|
||||
yiri_msg_list.append(mirai.Plain(text=content))
|
||||
yiri_msg_list.append(platform_message.Plain(text=content))
|
||||
|
||||
chain = mirai.MessageChain(yiri_msg_list)
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
|
||||
return chain
|
||||
|
||||
|
@ -244,10 +247,10 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
|||
self.member_openid_mapping = member_openid_mapping
|
||||
self.group_openid_mapping = group_openid_mapping
|
||||
|
||||
def yiri2target(self, event: typing.Type[mirai.Event]):
|
||||
if event == mirai.GroupMessage:
|
||||
def yiri2target(self, event: typing.Type[platform_events.Event]):
|
||||
if event == platform_events.GroupMessage:
|
||||
return botpy_message.Message
|
||||
elif event == mirai.FriendMessage:
|
||||
elif event == platform_events.FriendMessage:
|
||||
return botpy_message.DirectMessage
|
||||
else:
|
||||
raise Exception(
|
||||
|
@ -257,8 +260,7 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
|||
def target2yiri(
|
||||
self,
|
||||
event: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage],
|
||||
) -> mirai.Event:
|
||||
import mirai.models.entities as mirai_entities
|
||||
) -> platform_events.Event:
|
||||
|
||||
if type(event) == botpy_message.Message: # 频道内,转群聊事件
|
||||
permission = "MEMBER"
|
||||
|
@ -268,15 +270,15 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
|||
elif "4" in event.member.roles:
|
||||
permission = "OWNER"
|
||||
|
||||
return mirai.GroupMessage(
|
||||
sender=mirai_entities.GroupMember(
|
||||
return platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
id=event.author.id,
|
||||
member_name=event.author.username,
|
||||
permission=permission,
|
||||
group=mirai_entities.Group(
|
||||
group=platform_entities.Group(
|
||||
id=event.channel_id,
|
||||
name=event.author.username,
|
||||
permission=mirai_entities.Permission.Member,
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title="",
|
||||
join_timestamp=int(
|
||||
|
@ -297,8 +299,8 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
|||
),
|
||||
)
|
||||
elif type(event) == botpy_message.DirectMessage: # 频道私聊,转私聊事件
|
||||
return mirai.FriendMessage(
|
||||
sender=mirai_entities.Friend(
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=event.guild_id,
|
||||
nickname=event.author.username,
|
||||
remark=event.author.username,
|
||||
|
@ -317,14 +319,14 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
|||
replacing_member_id = self.member_openid_mapping.save_openid(event.author.member_openid)
|
||||
|
||||
return OfficialGroupMessage(
|
||||
sender=mirai_entities.GroupMember(
|
||||
sender=platform_entities.GroupMember(
|
||||
id=replacing_member_id,
|
||||
member_name=replacing_member_id,
|
||||
permission="MEMBER",
|
||||
group=mirai_entities.Group(
|
||||
group=platform_entities.Group(
|
||||
id=self.group_openid_mapping.save_openid(event.group_openid),
|
||||
name=replacing_member_id,
|
||||
permission=mirai_entities.Permission.Member,
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title="",
|
||||
join_timestamp=int(0),
|
||||
|
@ -345,7 +347,7 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
|||
user_id_alter = self.member_openid_mapping.save_openid(event.author.user_openid) # 实测这里的user_openid与group的member_openid是一样的
|
||||
|
||||
return OfficialFriendMessage(
|
||||
sender=mirai_entities.Friend(
|
||||
sender=platform_entities.Friend(
|
||||
id=user_id_alter,
|
||||
nickname=user_id_alter,
|
||||
remark=user_id_alter,
|
||||
|
@ -410,7 +412,7 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
|||
self.bot = botpy.Client(intents=intents)
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: mirai.MessageChain
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
message_list = self.message_converter.yiri2target(message)
|
||||
|
||||
|
@ -437,8 +439,8 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
|||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: mirai.MessageEvent,
|
||||
message: mirai.MessageChain,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
|
||||
|
@ -463,13 +465,13 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
|||
]
|
||||
)
|
||||
|
||||
if type(message_source) == mirai.GroupMessage:
|
||||
if type(message_source) == platform_events.GroupMessage:
|
||||
args["channel_id"] = str(message_source.sender.group.id)
|
||||
args["msg_id"] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
await self.bot.api.post_message(**args)
|
||||
elif type(message_source) == mirai.FriendMessage:
|
||||
elif type(message_source) == platform_events.FriendMessage:
|
||||
args["guild_id"] = str(message_source.sender.id)
|
||||
args["msg_id"] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
|
@ -534,9 +536,9 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
|||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[mirai.Event, adapter_model.MessageSourceAdapter], None
|
||||
[platform_events.Event, adapter_model.MessageSourceAdapter], None
|
||||
],
|
||||
):
|
||||
|
||||
|
@ -560,9 +562,9 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
|||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[mirai.Event, adapter_model.MessageSourceAdapter], None
|
||||
[platform_events.Event, adapter_model.MessageSourceAdapter], None
|
||||
],
|
||||
):
|
||||
delattr(self.bot, event_handler_mapping[event_type])
|
||||
|
|
|
@ -1,124 +1,121 @@
|
|||
import asyncio
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
import mirai.models.bus
|
||||
from mirai.bot import MiraiRunner
|
||||
|
||||
from .. import adapter as adapter_model
|
||||
from ...core import app
|
||||
# import asyncio
|
||||
# import typing
|
||||
|
||||
|
||||
@adapter_model.adapter_class("yiri-mirai")
|
||||
class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
|
||||
"""YiriMirai适配器"""
|
||||
bot: mirai.Mirai
|
||||
# from .. import adapter as adapter_model
|
||||
# from ...core import app
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application):
|
||||
"""初始化YiriMirai的对象"""
|
||||
self.ap = ap
|
||||
self.config = config
|
||||
if 'adapter' not in config or \
|
||||
config['adapter'] == 'WebSocketAdapter':
|
||||
self.bot = mirai.Mirai(
|
||||
qq=config['qq'],
|
||||
adapter=mirai.WebSocketAdapter(
|
||||
host=config['host'],
|
||||
port=config['port'],
|
||||
verify_key=config['verifyKey']
|
||||
)
|
||||
)
|
||||
elif config['adapter'] == 'HTTPAdapter':
|
||||
self.bot = mirai.Mirai(
|
||||
qq=config['qq'],
|
||||
adapter=mirai.HTTPAdapter(
|
||||
host=config['host'],
|
||||
port=config['port'],
|
||||
verify_key=config['verifyKey']
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise Exception('Unknown adapter for YiriMirai: ' + config['adapter'])
|
||||
|
||||
# @adapter_model.adapter_class("yiri-mirai")
|
||||
# class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
|
||||
# """YiriMirai适配器"""
|
||||
# bot: mirai.Mirai
|
||||
|
||||
# def __init__(self, config: dict, ap: app.Application):
|
||||
# """初始化YiriMirai的对象"""
|
||||
# self.ap = ap
|
||||
# self.config = config
|
||||
# if 'adapter' not in config or \
|
||||
# config['adapter'] == 'WebSocketAdapter':
|
||||
# self.bot = mirai.Mirai(
|
||||
# qq=config['qq'],
|
||||
# adapter=mirai.WebSocketAdapter(
|
||||
# host=config['host'],
|
||||
# port=config['port'],
|
||||
# verify_key=config['verifyKey']
|
||||
# )
|
||||
# )
|
||||
# elif config['adapter'] == 'HTTPAdapter':
|
||||
# self.bot = mirai.Mirai(
|
||||
# qq=config['qq'],
|
||||
# adapter=mirai.HTTPAdapter(
|
||||
# host=config['host'],
|
||||
# port=config['port'],
|
||||
# verify_key=config['verifyKey']
|
||||
# )
|
||||
# )
|
||||
# else:
|
||||
# raise Exception('Unknown adapter for YiriMirai: ' + config['adapter'])
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: mirai.MessageChain
|
||||
):
|
||||
"""发送消息
|
||||
# async def send_message(
|
||||
# self,
|
||||
# target_type: str,
|
||||
# target_id: str,
|
||||
# message: mirai.MessageChain
|
||||
# ):
|
||||
# """发送消息
|
||||
|
||||
Args:
|
||||
target_type (str): 目标类型,`person`或`group`
|
||||
target_id (str): 目标ID
|
||||
message (mirai.MessageChain): YiriMirai库的消息链
|
||||
"""
|
||||
task = None
|
||||
if target_type == 'person':
|
||||
task = self.bot.send_friend_message(int(target_id), message)
|
||||
elif target_type == 'group':
|
||||
task = self.bot.send_group_message(int(target_id), message)
|
||||
else:
|
||||
raise Exception('Unknown target type: ' + target_type)
|
||||
# Args:
|
||||
# target_type (str): 目标类型,`person`或`group`
|
||||
# target_id (str): 目标ID
|
||||
# message (mirai.MessageChain): YiriMirai库的消息链
|
||||
# """
|
||||
# task = None
|
||||
# if target_type == 'person':
|
||||
# task = self.bot.send_friend_message(int(target_id), message)
|
||||
# elif target_type == 'group':
|
||||
# task = self.bot.send_group_message(int(target_id), message)
|
||||
# else:
|
||||
# raise Exception('Unknown target type: ' + target_type)
|
||||
|
||||
await task
|
||||
# await task
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: mirai.MessageEvent,
|
||||
message: mirai.MessageChain,
|
||||
quote_origin: bool = False
|
||||
):
|
||||
"""回复消息
|
||||
# async def reply_message(
|
||||
# self,
|
||||
# message_source: mirai.MessageEvent,
|
||||
# message: mirai.MessageChain,
|
||||
# quote_origin: bool = False
|
||||
# ):
|
||||
# """回复消息
|
||||
|
||||
Args:
|
||||
message_source (mirai.MessageEvent): YiriMirai消息源事件
|
||||
message (mirai.MessageChain): YiriMirai库的消息链
|
||||
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
|
||||
"""
|
||||
await self.bot.send(message_source, message, quote_origin)
|
||||
# Args:
|
||||
# message_source (mirai.MessageEvent): YiriMirai消息源事件
|
||||
# message (mirai.MessageChain): YiriMirai库的消息链
|
||||
# quote_origin (bool, optional): 是否引用原消息. Defaults to False.
|
||||
# """
|
||||
# await self.bot.send(message_source, message, quote_origin)
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get()
|
||||
if result.mute_time_remaining > 0:
|
||||
return True
|
||||
return False
|
||||
# async def is_muted(self, group_id: int) -> bool:
|
||||
# result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get()
|
||||
# if result.mute_time_remaining > 0:
|
||||
# return True
|
||||
# return False
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
):
|
||||
"""注册事件监听器
|
||||
# def register_listener(
|
||||
# self,
|
||||
# event_type: typing.Type[mirai.Event],
|
||||
# callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
# ):
|
||||
# """注册事件监听器
|
||||
|
||||
Args:
|
||||
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
|
||||
callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件
|
||||
"""
|
||||
async def wrapper(event: mirai.Event):
|
||||
await callback(event, self)
|
||||
self.bot.on(event_type)(wrapper)
|
||||
# Args:
|
||||
# event_type (typing.Type[mirai.Event]): YiriMirai事件类型
|
||||
# callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件
|
||||
# """
|
||||
# async def wrapper(event: mirai.Event):
|
||||
# await callback(event, self)
|
||||
# self.bot.on(event_type)(wrapper)
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
):
|
||||
"""注销事件监听器
|
||||
# def unregister_listener(
|
||||
# self,
|
||||
# event_type: typing.Type[mirai.Event],
|
||||
# callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
# ):
|
||||
# """注销事件监听器
|
||||
|
||||
Args:
|
||||
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
|
||||
callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件
|
||||
"""
|
||||
assert isinstance(self.bot, mirai.Mirai)
|
||||
bus = self.bot.bus
|
||||
assert isinstance(bus, mirai.models.bus.ModelEventBus)
|
||||
# Args:
|
||||
# event_type (typing.Type[mirai.Event]): YiriMirai事件类型
|
||||
# callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件
|
||||
# """
|
||||
# assert isinstance(self.bot, mirai.Mirai)
|
||||
# bus = self.bot.bus
|
||||
# assert isinstance(bus, mirai.models.bus.ModelEventBus)
|
||||
|
||||
bus.unsubscribe(event_type, callback)
|
||||
# bus.unsubscribe(event_type, callback)
|
||||
|
||||
async def run_async(self):
|
||||
self.bot_account_id = self.bot.qq
|
||||
return await MiraiRunner(self.bot)._run()
|
||||
# async def run_async(self):
|
||||
# self.bot_account_id = self.bot.qq
|
||||
# return await MiraiRunner(self.bot)._run()
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
# async def kill(self) -> bool:
|
||||
# return False
|
||||
|
|
3
pkg/platform/types/__init__.py
Normal file
3
pkg/platform/types/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from .entities import *
|
||||
from .events import *
|
||||
from .message import *
|
105
pkg/platform/types/base.py
Normal file
105
pkg/platform/types/base.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
|
||||
from typing import Dict, List, Type
|
||||
|
||||
import pydantic.main as pdm
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PlatformMetaclass(pdm.ModelMetaclass):
|
||||
"""此类是平台中使用的 pydantic 模型的元类的基类。"""
|
||||
|
||||
|
||||
def to_camel(name: str) -> str:
|
||||
"""将下划线命名风格转换为小驼峰命名。"""
|
||||
if name[:2] == '__': # 不处理双下划线开头的特殊命名。
|
||||
return name
|
||||
name_parts = name.split('_')
|
||||
return ''.join(name_parts[:1] + [x.title() for x in name_parts[1:]])
|
||||
|
||||
|
||||
class PlatformBaseModel(BaseModel, metaclass=PlatformMetaclass):
|
||||
"""模型基类。
|
||||
|
||||
启用了三项配置:
|
||||
1. 允许解析时传入额外的值,并将额外值保存在模型中。
|
||||
2. 允许通过别名访问字段。
|
||||
3. 自动生成小驼峰风格的别名。
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
""""""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__class__.__name__ + '(' + ', '.join(
|
||||
(f'{k}={repr(v)}' for k, v in self.__dict__.items() if v)
|
||||
) + ')'
|
||||
|
||||
class Config:
|
||||
extra = 'allow'
|
||||
allow_population_by_field_name = True
|
||||
alias_generator = to_camel
|
||||
|
||||
|
||||
class PlatformIndexedMetaclass(PlatformMetaclass):
|
||||
"""可以通过子类名获取子类的类的元类。"""
|
||||
__indexedbases__: List[Type['PlatformIndexedModel']] = []
|
||||
__indexedmodel__ = None
|
||||
|
||||
def __new__(cls, name, bases, attrs, **kwargs):
|
||||
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
|
||||
# 第一类:PlatformIndexedModel
|
||||
if name == 'PlatformIndexedModel':
|
||||
cls.__indexedmodel__ = new_cls
|
||||
new_cls.__indexes__ = {}
|
||||
return new_cls
|
||||
# 第二类:PlatformIndexedModel 的直接子类,这些是可以通过子类名获取子类的类。
|
||||
if cls.__indexedmodel__ in bases:
|
||||
cls.__indexedbases__.append(new_cls)
|
||||
new_cls.__indexes__ = {}
|
||||
return new_cls
|
||||
# 第三类:PlatformIndexedModel 的直接子类的子类,这些添加到直接子类的索引中。
|
||||
for base in cls.__indexedbases__:
|
||||
if issubclass(new_cls, base):
|
||||
base.__indexes__[name] = new_cls
|
||||
return new_cls
|
||||
|
||||
def __getitem__(cls, name):
|
||||
return cls.get_subtype(name)
|
||||
|
||||
|
||||
class PlatformIndexedModel(PlatformBaseModel, metaclass=PlatformIndexedMetaclass):
|
||||
"""可以通过子类名获取子类的类。"""
|
||||
__indexes__: Dict[str, Type['PlatformIndexedModel']]
|
||||
|
||||
@classmethod
|
||||
def get_subtype(cls, name: str) -> Type['PlatformIndexedModel']:
|
||||
"""根据类名称,获取相应的子类类型。
|
||||
|
||||
Args:
|
||||
name: 类名称。
|
||||
|
||||
Returns:
|
||||
Type['PlatformIndexedModel']: 子类类型。
|
||||
"""
|
||||
try:
|
||||
type_ = cls.__indexes__.get(name)
|
||||
if not (type_ and issubclass(type_, cls)):
|
||||
raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!')
|
||||
return type_
|
||||
except AttributeError as e:
|
||||
raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!') from None
|
||||
|
||||
@classmethod
|
||||
def parse_subtype(cls, obj: dict) -> 'PlatformIndexedModel':
|
||||
"""通过字典,构造对应的模型对象。
|
||||
|
||||
Args:
|
||||
obj: 一个字典,包含了模型对象的属性。
|
||||
|
||||
Returns:
|
||||
PlatformIndexedModel: 构造的对象。
|
||||
"""
|
||||
if cls in PlatformIndexedModel.__subclasses__():
|
||||
ModelType = cls.get_subtype(obj['type'])
|
||||
return ModelType.parse_obj(obj)
|
||||
return super().parse_obj(obj)
|
143
pkg/platform/types/entities.py
Normal file
143
pkg/platform/types/entities.py
Normal file
|
@ -0,0 +1,143 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
此模块提供实体和配置项模型。
|
||||
"""
|
||||
import abc
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class Entity(pydantic.BaseModel):
|
||||
"""实体,表示一个用户或群。"""
|
||||
id: int
|
||||
"""QQ 号或群号。"""
|
||||
@abc.abstractmethod
|
||||
def get_avatar_url(self) -> str:
|
||||
"""头像图片链接。"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_name(self) -> str:
|
||||
"""名称。"""
|
||||
|
||||
|
||||
class Friend(Entity):
|
||||
"""好友。"""
|
||||
id: int
|
||||
"""QQ 号。"""
|
||||
nickname: typing.Optional[str]
|
||||
"""昵称。"""
|
||||
remark: typing.Optional[str]
|
||||
"""备注。"""
|
||||
def get_avatar_url(self) -> str:
|
||||
return f'http://q4.qlogo.cn/g?b=qq&nk={self.id}&s=140'
|
||||
|
||||
def get_name(self) -> str:
|
||||
return self.nickname or self.remark or ''
|
||||
|
||||
|
||||
class Permission(str, Enum):
|
||||
"""群成员身份权限。"""
|
||||
Member = "MEMBER"
|
||||
"""成员。"""
|
||||
Administrator = "ADMINISTRATOR"
|
||||
"""管理员。"""
|
||||
Owner = "OWNER"
|
||||
"""群主。"""
|
||||
def __repr__(self) -> str:
|
||||
return repr(self.value)
|
||||
|
||||
|
||||
class Group(Entity):
|
||||
"""群。"""
|
||||
id: int
|
||||
"""群号。"""
|
||||
name: str
|
||||
"""群名称。"""
|
||||
permission: Permission
|
||||
"""Bot 在群中的权限。"""
|
||||
def get_avatar_url(self) -> str:
|
||||
return f'https://p.qlogo.cn/gh/{self.id}/{self.id}/'
|
||||
|
||||
def get_name(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
class GroupMember(Entity):
|
||||
"""群成员。"""
|
||||
id: int
|
||||
"""QQ 号。"""
|
||||
member_name: str
|
||||
"""群成员名称。"""
|
||||
permission: Permission
|
||||
"""Bot 在群中的权限。"""
|
||||
group: Group
|
||||
"""群。"""
|
||||
special_title: str = ''
|
||||
"""群头衔。"""
|
||||
join_timestamp: datetime = datetime.utcfromtimestamp(0)
|
||||
"""加入群的时间。"""
|
||||
last_speak_timestamp: datetime = datetime.utcfromtimestamp(0)
|
||||
"""最后一次发言的时间。"""
|
||||
mute_time_remaining: int = 0
|
||||
"""禁言剩余时间。"""
|
||||
def get_avatar_url(self) -> str:
|
||||
return f'http://q4.qlogo.cn/g?b=qq&nk={self.id}&s=140'
|
||||
|
||||
def get_name(self) -> str:
|
||||
return self.member_name
|
||||
|
||||
|
||||
class Client(Entity):
|
||||
"""来自其他客户端的用户。"""
|
||||
id: int
|
||||
"""识别 id。"""
|
||||
platform: str
|
||||
"""来源平台。"""
|
||||
def get_avatar_url(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_name(self) -> str:
|
||||
return self.platform
|
||||
|
||||
|
||||
class Subject(pydantic.BaseModel):
|
||||
"""另一种实体类型表示。"""
|
||||
id: int
|
||||
"""QQ 号或群号。"""
|
||||
kind: typing.Literal['Friend', 'Group', 'Stranger']
|
||||
"""类型。"""
|
||||
|
||||
|
||||
class Config(pydantic.BaseModel):
|
||||
"""配置项类型。"""
|
||||
def modify(self, **kwargs) -> 'Config':
|
||||
"""修改部分设置。"""
|
||||
for k, v in kwargs.items():
|
||||
if k in self.__fields__:
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
raise ValueError(f'未知配置项: {k}')
|
||||
return self
|
||||
|
||||
|
||||
class GroupConfigModel(Config):
|
||||
"""群配置。"""
|
||||
name: str
|
||||
"""群名称。"""
|
||||
confess_talk: bool
|
||||
"""是否允许坦白说。"""
|
||||
allow_member_invite: bool
|
||||
"""是否允许成员邀请好友入群。"""
|
||||
auto_approve: bool
|
||||
"""是否开启自动审批入群。"""
|
||||
anonymous_chat: bool
|
||||
"""是否开启匿名聊天。"""
|
||||
announcement: str = ''
|
||||
"""群公告。"""
|
||||
|
||||
|
||||
class MemberInfoModel(Config, GroupMember):
|
||||
"""群成员信息。"""
|
124
pkg/platform/types/events.py
Normal file
124
pkg/platform/types/events.py
Normal file
|
@ -0,0 +1,124 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
此模块提供事件模型。
|
||||
"""
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
from . import entities as platform_entities
|
||||
from . import message as platform_message
|
||||
|
||||
|
||||
class Event(pydantic.BaseModel):
|
||||
"""事件基类。
|
||||
|
||||
Args:
|
||||
type: 事件名。
|
||||
"""
|
||||
type: str
|
||||
"""事件名。"""
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(' + ', '.join(
|
||||
(
|
||||
f'{k}={repr(v)}'
|
||||
for k, v in self.__dict__.items() if k != 'type' and v
|
||||
)
|
||||
) + ')'
|
||||
|
||||
@classmethod
|
||||
def parse_subtype(cls, obj: dict) -> 'Event':
|
||||
try:
|
||||
return typing.cast(Event, super().parse_subtype(obj))
|
||||
except ValueError:
|
||||
return Event(type=obj['type'])
|
||||
|
||||
@classmethod
|
||||
def get_subtype(cls, name: str) -> typing.Type['Event']:
|
||||
try:
|
||||
return typing.cast(typing.Type[Event], super().get_subtype(name))
|
||||
except ValueError:
|
||||
return Event
|
||||
|
||||
|
||||
###############################
|
||||
# Bot Event
|
||||
class BotEvent(Event):
|
||||
"""Bot 自身事件。
|
||||
|
||||
Args:
|
||||
type: 事件名。
|
||||
qq: Bot 的 QQ 号。
|
||||
"""
|
||||
type: str
|
||||
"""事件名。"""
|
||||
qq: int
|
||||
"""Bot 的 QQ 号。"""
|
||||
|
||||
|
||||
###############################
|
||||
# Message Event
|
||||
class MessageEvent(Event):
|
||||
"""消息事件。
|
||||
|
||||
Args:
|
||||
type: 事件名。
|
||||
message_chain: 消息内容。
|
||||
"""
|
||||
type: str
|
||||
"""事件名。"""
|
||||
message_chain: platform_message.MessageChain
|
||||
"""消息内容。"""
|
||||
|
||||
|
||||
class FriendMessage(MessageEvent):
|
||||
"""好友消息。
|
||||
|
||||
Args:
|
||||
type: 事件名。
|
||||
sender: 发送消息的好友。
|
||||
message_chain: 消息内容。
|
||||
"""
|
||||
type: str = 'FriendMessage'
|
||||
"""事件名。"""
|
||||
sender: platform_entities.Friend
|
||||
"""发送消息的好友。"""
|
||||
message_chain: platform_message.MessageChain
|
||||
"""消息内容。"""
|
||||
|
||||
|
||||
class GroupMessage(MessageEvent):
|
||||
"""群消息。
|
||||
|
||||
Args:
|
||||
type: 事件名。
|
||||
sender: 发送消息的群成员。
|
||||
message_chain: 消息内容。
|
||||
"""
|
||||
type: str = 'GroupMessage'
|
||||
"""事件名。"""
|
||||
sender: platform_entities.GroupMember
|
||||
"""发送消息的群成员。"""
|
||||
message_chain: platform_message.MessageChain
|
||||
"""消息内容。"""
|
||||
@property
|
||||
def group(self) -> platform_entities.Group:
|
||||
return self.sender.group
|
||||
|
||||
|
||||
class StrangerMessage(MessageEvent):
|
||||
"""陌生人消息。
|
||||
|
||||
Args:
|
||||
type: 事件名。
|
||||
sender: 发送消息的人。
|
||||
message_chain: 消息内容。
|
||||
"""
|
||||
type: str = 'StrangerMessage'
|
||||
"""事件名。"""
|
||||
sender: platform_entities.Friend
|
||||
"""发送消息的人。"""
|
||||
message_chain: platform_message.MessageChain
|
||||
"""消息内容。"""
|
817
pkg/platform/types/message.py
Normal file
817
pkg/platform/types/message.py
Normal file
|
@ -0,0 +1,817 @@
|
|||
import itertools
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
import pydantic.main
|
||||
|
||||
from . import entities as platform_entities
|
||||
from .base import PlatformBaseModel, PlatformIndexedMetaclass, PlatformIndexedModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageComponentMetaclass(PlatformIndexedMetaclass):
|
||||
"""消息组件元类。"""
|
||||
__message_component__ = None
|
||||
|
||||
def __new__(cls, name, bases, attrs, **kwargs):
|
||||
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
|
||||
if name == 'MessageComponent':
|
||||
cls.__message_component__ = new_cls
|
||||
|
||||
if not cls.__message_component__:
|
||||
return new_cls
|
||||
|
||||
for base in bases:
|
||||
if issubclass(base, cls.__message_component__):
|
||||
# 获取字段名
|
||||
if hasattr(new_cls, '__fields__'):
|
||||
# 忽略 type 字段
|
||||
new_cls.__parameter_names__ = list(new_cls.__fields__)[1:]
|
||||
else:
|
||||
new_cls.__parameter_names__ = []
|
||||
break
|
||||
|
||||
return new_cls
|
||||
|
||||
|
||||
class MessageComponent(PlatformIndexedModel, metaclass=MessageComponentMetaclass):
|
||||
"""消息组件。"""
|
||||
type: str
|
||||
"""消息组件类型。"""
|
||||
def __str__(self):
|
||||
return ''
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(' + ', '.join(
|
||||
(
|
||||
f'{k}={repr(v)}'
|
||||
for k, v in self.__dict__.items() if k != 'type' and v
|
||||
)
|
||||
) + ')'
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# 解析参数列表,将位置参数转化为具名参数
|
||||
parameter_names = self.__parameter_names__
|
||||
if len(args) > len(parameter_names):
|
||||
raise TypeError(
|
||||
f'`{self.type}`需要{len(parameter_names)}个参数,但传入了{len(args)}个。'
|
||||
)
|
||||
for name, value in zip(parameter_names, args):
|
||||
if name in kwargs:
|
||||
raise TypeError(f'在 `{self.type}` 中,具名参数 `{name}` 与位置参数重复。')
|
||||
kwargs[name] = value
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
TMessageComponent = typing.TypeVar('TMessageComponent', bound=MessageComponent)
|
||||
|
||||
|
||||
class MessageChain(PlatformBaseModel):
|
||||
"""消息链。
|
||||
|
||||
一个构造消息链的例子:
|
||||
```py
|
||||
message_chain = MessageChain([
|
||||
AtAll(),
|
||||
Plain("Hello World!"),
|
||||
])
|
||||
```
|
||||
|
||||
`Plain` 可以省略。
|
||||
```py
|
||||
message_chain = MessageChain([
|
||||
AtAll(),
|
||||
"Hello World!",
|
||||
])
|
||||
```
|
||||
|
||||
在调用 API 时,参数中需要 MessageChain 的,也可以使用 `List[MessageComponent]` 代替。
|
||||
例如,以下两种写法是等价的:
|
||||
```py
|
||||
await bot.send_friend_message(12345678, [
|
||||
Plain("Hello World!")
|
||||
])
|
||||
```
|
||||
```py
|
||||
await bot.send_friend_message(12345678, MessageChain([
|
||||
Plain("Hello World!")
|
||||
]))
|
||||
```
|
||||
|
||||
可以使用 `in` 运算检查消息链中:
|
||||
1. 是否有某个消息组件。
|
||||
2. 是否有某个类型的消息组件。
|
||||
|
||||
```py
|
||||
if AtAll in message_chain:
|
||||
print('AtAll')
|
||||
|
||||
if At(bot.qq) in message_chain:
|
||||
print('At Me')
|
||||
```
|
||||
|
||||
消息链对索引操作进行了增强。以消息组件类型为索引,获取消息链中的全部该类型的消息组件。
|
||||
```py
|
||||
plain_list = message_chain[Plain]
|
||||
'[Plain("Hello World!")]'
|
||||
```
|
||||
|
||||
可以用加号连接两个消息链。
|
||||
```py
|
||||
MessageChain(['Hello World!']) + MessageChain(['Goodbye World!'])
|
||||
# 返回 MessageChain([Plain("Hello World!"), Plain("Goodbye World!")])
|
||||
```
|
||||
|
||||
"""
|
||||
__root__: typing.List[MessageComponent]
|
||||
|
||||
@staticmethod
|
||||
def _parse_message_chain(msg_chain: typing.Iterable):
|
||||
result = []
|
||||
for msg in msg_chain:
|
||||
if isinstance(msg, dict):
|
||||
result.append(MessageComponent.parse_subtype(msg))
|
||||
elif isinstance(msg, MessageComponent):
|
||||
result.append(msg)
|
||||
elif isinstance(msg, str):
|
||||
result.append(Plain(msg))
|
||||
else:
|
||||
raise TypeError(
|
||||
f"消息链中元素需为 dict 或 str 或 MessageComponent,当前类型:{type(msg)}"
|
||||
)
|
||||
return result
|
||||
|
||||
@pydantic.validator('__root__', always=True, pre=True)
|
||||
def _parse_component(cls, msg_chain):
|
||||
if isinstance(msg_chain, (str, MessageComponent)):
|
||||
msg_chain = [msg_chain]
|
||||
if not msg_chain:
|
||||
msg_chain = []
|
||||
return cls._parse_message_chain(msg_chain)
|
||||
|
||||
@classmethod
|
||||
def parse_obj(cls, msg_chain: typing.Iterable):
|
||||
"""通过列表形式的消息链,构造对应的 `MessageChain` 对象。
|
||||
|
||||
Args:
|
||||
msg_chain: 列表形式的消息链。
|
||||
"""
|
||||
result = cls._parse_message_chain(msg_chain)
|
||||
return cls(__root__=result)
|
||||
|
||||
def __init__(self, __root__: typing.Iterable[MessageComponent] = None):
|
||||
super().__init__(__root__=__root__)
|
||||
|
||||
def __str__(self):
|
||||
return "".join(str(component) for component in self.__root__)
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.__root__!r})'
|
||||
|
||||
def __iter__(self):
|
||||
yield from self.__root__
|
||||
|
||||
def get_first(self,
|
||||
t: typing.Type[TMessageComponent]) -> typing.Optional[TMessageComponent]:
|
||||
"""获取消息链中第一个符合类型的消息组件。"""
|
||||
for component in self:
|
||||
if isinstance(component, t):
|
||||
return component
|
||||
return None
|
||||
|
||||
@typing.overload
|
||||
def __getitem__(self, index: int) -> MessageComponent:
|
||||
...
|
||||
|
||||
@typing.overload
|
||||
def __getitem__(self, index: slice) -> typing.List[MessageComponent]:
|
||||
...
|
||||
|
||||
@typing.overload
|
||||
def __getitem__(self,
|
||||
index: typing.Type[TMessageComponent]) -> typing.List[TMessageComponent]:
|
||||
...
|
||||
|
||||
@typing.overload
|
||||
def __getitem__(
|
||||
self, index: typing.Tuple[typing.Type[TMessageComponent], int]
|
||||
) -> typing.List[TMessageComponent]:
|
||||
...
|
||||
|
||||
def __getitem__(
|
||||
self, index: typing.Union[int, slice, typing.Type[TMessageComponent],
|
||||
typing.Tuple[typing.Type[TMessageComponent], int]]
|
||||
) -> typing.Union[MessageComponent, typing.List[MessageComponent],
|
||||
typing.List[TMessageComponent]]:
|
||||
return self.get(index)
|
||||
|
||||
def __setitem__(
|
||||
self, key: typing.Union[int, slice],
|
||||
value: typing.Union[MessageComponent, str, typing.Iterable[typing.Union[MessageComponent,
|
||||
str]]]
|
||||
):
|
||||
if isinstance(value, str):
|
||||
value = Plain(value)
|
||||
if isinstance(value, typing.Iterable):
|
||||
value = (Plain(c) if isinstance(c, str) else c for c in value)
|
||||
self.__root__[key] = value # type: ignore
|
||||
|
||||
def __delitem__(self, key: typing.Union[int, slice]):
|
||||
del self.__root__[key]
|
||||
|
||||
def __reversed__(self) -> typing.Iterable[MessageComponent]:
|
||||
return reversed(self.__root__)
|
||||
|
||||
def has(
|
||||
self, sub: typing.Union[MessageComponent, typing.Type[MessageComponent],
|
||||
'MessageChain', str]
|
||||
) -> bool:
|
||||
"""判断消息链中:
|
||||
1. 是否有某个消息组件。
|
||||
2. 是否有某个类型的消息组件。
|
||||
|
||||
Args:
|
||||
sub (`Union[MessageComponent, Type[MessageComponent], 'MessageChain', str]`):
|
||||
若为 `MessageComponent`,则判断该组件是否在消息链中。
|
||||
若为 `Type[MessageComponent]`,则判断该组件类型是否在消息链中。
|
||||
|
||||
Returns:
|
||||
bool: 是否找到。
|
||||
"""
|
||||
if isinstance(sub, type): # 检测消息链中是否有某种类型的对象
|
||||
for i in self:
|
||||
if type(i) is sub:
|
||||
return True
|
||||
return False
|
||||
if isinstance(sub, MessageComponent): # 检查消息链中是否有某个组件
|
||||
for i in self:
|
||||
if i == sub:
|
||||
return True
|
||||
return False
|
||||
raise TypeError(f"类型不匹配,当前类型:{type(sub)}")
|
||||
|
||||
def __contains__(self, sub) -> bool:
|
||||
return self.has(sub)
|
||||
|
||||
def __ge__(self, other):
|
||||
return other in self
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.__root__)
|
||||
|
||||
def __add__(
|
||||
self, other: typing.Union['MessageChain', MessageComponent, str]
|
||||
) -> 'MessageChain':
|
||||
if isinstance(other, MessageChain):
|
||||
return self.__class__(self.__root__ + other.__root__)
|
||||
if isinstance(other, str):
|
||||
return self.__class__(self.__root__ + [Plain(other)])
|
||||
if isinstance(other, MessageComponent):
|
||||
return self.__class__(self.__root__ + [other])
|
||||
return NotImplemented
|
||||
|
||||
def __radd__(self, other: typing.Union[MessageComponent, str]) -> 'MessageChain':
|
||||
if isinstance(other, MessageComponent):
|
||||
return self.__class__([other] + self.__root__)
|
||||
if isinstance(other, str):
|
||||
return self.__class__(
|
||||
[typing.cast(MessageComponent, Plain(other))] + self.__root__
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def __mul__(self, other: int):
|
||||
if isinstance(other, int):
|
||||
return self.__class__(self.__root__ * other)
|
||||
return NotImplemented
|
||||
|
||||
def __rmul__(self, other: int):
|
||||
return self.__mul__(other)
|
||||
|
||||
def __iadd__(self, other: typing.Iterable[typing.Union[MessageComponent, str]]):
|
||||
self.extend(other)
|
||||
|
||||
def __imul__(self, other: int):
|
||||
if isinstance(other, int):
|
||||
self.__root__ *= other
|
||||
return NotImplemented
|
||||
|
||||
def index(
|
||||
self,
|
||||
x: typing.Union[MessageComponent, typing.Type[MessageComponent]],
|
||||
i: int = 0,
|
||||
j: int = -1
|
||||
) -> int:
|
||||
"""返回 x 在消息链中首次出现项的索引号(索引号在 i 或其后且在 j 之前)。
|
||||
|
||||
Args:
|
||||
x (`Union[MessageComponent, Type[MessageComponent]]`):
|
||||
要查找的消息元素或消息元素类型。
|
||||
i: 从哪个位置开始查找。
|
||||
j: 查找到哪个位置结束。
|
||||
|
||||
Returns:
|
||||
int: 如果找到,则返回索引号。
|
||||
|
||||
Raises:
|
||||
ValueError: 没有找到。
|
||||
TypeError: 类型不匹配。
|
||||
"""
|
||||
if isinstance(x, type):
|
||||
l = len(self)
|
||||
if i < 0:
|
||||
i += l
|
||||
if i < 0:
|
||||
i = 0
|
||||
if j < 0:
|
||||
j += l
|
||||
if j > l:
|
||||
j = l
|
||||
for index in range(i, j):
|
||||
if type(self[index]) is x:
|
||||
return index
|
||||
raise ValueError("消息链中不存在该类型的组件。")
|
||||
if isinstance(x, MessageComponent):
|
||||
return self.__root__.index(x, i, j)
|
||||
raise TypeError(f"类型不匹配,当前类型:{type(x)}")
|
||||
|
||||
def count(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]) -> int:
|
||||
"""返回消息链中 x 出现的次数。
|
||||
|
||||
Args:
|
||||
x (`Union[MessageComponent, Type[MessageComponent]]`):
|
||||
要查找的消息元素或消息元素类型。
|
||||
|
||||
Returns:
|
||||
int: 次数。
|
||||
"""
|
||||
if isinstance(x, type):
|
||||
return sum(1 for i in self if type(i) is x)
|
||||
if isinstance(x, MessageComponent):
|
||||
return self.__root__.count(x)
|
||||
raise TypeError(f"类型不匹配,当前类型:{type(x)}")
|
||||
|
||||
def extend(self, x: typing.Iterable[typing.Union[MessageComponent, str]]):
|
||||
"""将另一个消息链中的元素添加到消息链末尾。
|
||||
|
||||
Args:
|
||||
x: 另一个消息链,也可为消息元素或字符串元素的序列。
|
||||
"""
|
||||
self.__root__.extend(Plain(c) if isinstance(c, str) else c for c in x)
|
||||
|
||||
def append(self, x: typing.Union[MessageComponent, str]):
|
||||
"""将一个消息元素或字符串元素添加到消息链末尾。
|
||||
|
||||
Args:
|
||||
x: 消息元素或字符串元素。
|
||||
"""
|
||||
self.__root__.append(Plain(x) if isinstance(x, str) else x)
|
||||
|
||||
def insert(self, i: int, x: typing.Union[MessageComponent, str]):
|
||||
"""将一个消息元素或字符串添加到消息链中指定位置。
|
||||
|
||||
Args:
|
||||
i: 插入位置。
|
||||
x: 消息元素或字符串元素。
|
||||
"""
|
||||
self.__root__.insert(i, Plain(x) if isinstance(x, str) else x)
|
||||
|
||||
def pop(self, i: int = -1) -> MessageComponent:
|
||||
"""从消息链中移除并返回指定位置的元素。
|
||||
|
||||
Args:
|
||||
i: 移除位置。默认为末尾。
|
||||
|
||||
Returns:
|
||||
MessageComponent: 移除的元素。
|
||||
"""
|
||||
return self.__root__.pop(i)
|
||||
|
||||
def remove(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]):
|
||||
"""从消息链中移除指定元素或指定类型的一个元素。
|
||||
|
||||
Args:
|
||||
x: 指定的元素或元素类型。
|
||||
"""
|
||||
if isinstance(x, type):
|
||||
self.pop(self.index(x))
|
||||
if isinstance(x, MessageComponent):
|
||||
self.__root__.remove(x)
|
||||
|
||||
def exclude(
|
||||
self,
|
||||
x: typing.Union[MessageComponent, typing.Type[MessageComponent]],
|
||||
count: int = -1
|
||||
) -> 'MessageChain':
|
||||
"""返回移除指定元素或指定类型的元素后剩余的消息链。
|
||||
|
||||
Args:
|
||||
x: 指定的元素或元素类型。
|
||||
count: 至多移除的数量。默认为全部移除。
|
||||
|
||||
Returns:
|
||||
MessageChain: 剩余的消息链。
|
||||
"""
|
||||
def _exclude():
|
||||
nonlocal count
|
||||
x_is_type = isinstance(x, type)
|
||||
for c in self:
|
||||
if count > 0 and ((x_is_type and type(c) is x) or c == x):
|
||||
count -= 1
|
||||
continue
|
||||
yield c
|
||||
|
||||
return self.__class__(_exclude())
|
||||
|
||||
def reverse(self):
|
||||
"""将消息链原地翻转。"""
|
||||
self.__root__.reverse()
|
||||
|
||||
@classmethod
|
||||
def join(cls, *args: typing.Iterable[typing.Union[str, MessageComponent]]):
|
||||
return cls(
|
||||
Plain(c) if isinstance(c, str) else c
|
||||
for c in itertools.chain(*args)
|
||||
)
|
||||
|
||||
@property
|
||||
def source(self) -> typing.Optional['Source']:
|
||||
"""获取消息链中的 `Source` 对象。"""
|
||||
return self.get_first(Source)
|
||||
|
||||
@property
|
||||
def message_id(self) -> int:
|
||||
"""获取消息链的 message_id,若无法获取,返回 -1。"""
|
||||
source = self.source
|
||||
return source.id if source else -1
|
||||
|
||||
|
||||
TMessage = typing.Union[MessageChain, typing.Iterable[typing.Union[MessageComponent, str]],
|
||||
MessageComponent, str]
|
||||
"""可以转化为 MessageChain 的类型。"""
|
||||
|
||||
|
||||
class Source(MessageComponent):
|
||||
"""源。包含消息的基本信息。"""
|
||||
type: str = "Source"
|
||||
"""消息组件类型。"""
|
||||
id: int
|
||||
"""消息的识别号,用于引用回复(Source 类型永远为 MessageChain 的第一个元素)。"""
|
||||
time: datetime
|
||||
"""消息时间。"""
|
||||
|
||||
|
||||
class Plain(MessageComponent):
|
||||
"""纯文本。"""
|
||||
type: str = "Plain"
|
||||
"""消息组件类型。"""
|
||||
text: str
|
||||
"""文字消息。"""
|
||||
def __str__(self):
|
||||
return self.text
|
||||
|
||||
def __repr__(self):
|
||||
return f'Plain({self.text!r})'
|
||||
|
||||
|
||||
class Quote(MessageComponent):
|
||||
"""引用。"""
|
||||
type: str = "Quote"
|
||||
"""消息组件类型。"""
|
||||
id: typing.Optional[int] = None
|
||||
"""被引用回复的原消息的 message_id。"""
|
||||
group_id: typing.Optional[int] = None
|
||||
"""被引用回复的原消息所接收的群号,当为好友消息时为0。"""
|
||||
sender_id: typing.Optional[int] = None
|
||||
"""被引用回复的原消息的发送者的QQ号。"""
|
||||
target_id: typing.Optional[int] = None
|
||||
"""被引用回复的原消息的接收者者的QQ号(或群号)。"""
|
||||
origin: MessageChain
|
||||
"""被引用回复的原消息的消息链对象。"""
|
||||
|
||||
@pydantic.validator("origin", always=True, pre=True)
|
||||
def origin_formater(cls, v):
|
||||
return MessageChain.parse_obj(v)
|
||||
|
||||
|
||||
class At(MessageComponent):
|
||||
"""At某人。"""
|
||||
type: str = "At"
|
||||
"""消息组件类型。"""
|
||||
target: int
|
||||
"""群员 QQ 号。"""
|
||||
display: typing.Optional[str] = None
|
||||
"""At时显示的文字,发送消息时无效,自动使用群名片。"""
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, At) and self.target == other.target
|
||||
|
||||
def __str__(self):
|
||||
return f"@{self.display or self.target}"
|
||||
|
||||
|
||||
class AtAll(MessageComponent):
|
||||
"""At全体。"""
|
||||
type: str = "AtAll"
|
||||
"""消息组件类型。"""
|
||||
def __str__(self):
|
||||
return "@全体成员"
|
||||
|
||||
|
||||
class Image(MessageComponent):
|
||||
"""图片。"""
|
||||
type: str = "Image"
|
||||
"""消息组件类型。"""
|
||||
image_id: typing.Optional[str] = None
|
||||
"""图片的 image_id,群图片与好友图片格式不同。不为空时将忽略 url 属性。"""
|
||||
url: typing.Optional[pydantic.HttpUrl] = None
|
||||
"""图片的 URL,发送时可作网络图片的链接;接收时为腾讯图片服务器的链接,可用于图片下载。"""
|
||||
path: typing.Union[str, Path, None] = None
|
||||
"""图片的路径,发送本地图片。"""
|
||||
base64: typing.Optional[str] = None
|
||||
"""图片的 Base64 编码。"""
|
||||
def __eq__(self, other):
|
||||
return isinstance(
|
||||
other, Image
|
||||
) and self.type == other.type and self.uuid == other.uuid
|
||||
|
||||
def __str__(self):
|
||||
return '[图片]'
|
||||
|
||||
@pydantic.validator('path')
|
||||
def validate_path(cls, path: typing.Union[str, Path, None]):
|
||||
"""修复 path 参数的行为,使之相对于 QChatGPT 的启动路径。"""
|
||||
if path:
|
||||
try:
|
||||
return str(Path(path).resolve(strict=True))
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"无效路径:{path}")
|
||||
else:
|
||||
return path
|
||||
|
||||
@property
|
||||
def uuid(self):
|
||||
image_id = self.image_id
|
||||
if image_id[0] == '{': # 群图片
|
||||
image_id = image_id[1:37]
|
||||
elif image_id[0] == '/': # 好友图片
|
||||
image_id = image_id[1:]
|
||||
return image_id
|
||||
|
||||
async def download(
|
||||
self,
|
||||
filename: typing.Union[str, Path, None] = None,
|
||||
directory: typing.Union[str, Path, None] = None,
|
||||
determine_type: bool = True
|
||||
):
|
||||
"""下载图片到本地。
|
||||
|
||||
Args:
|
||||
filename: 下载到本地的文件路径。与 `directory` 二选一。
|
||||
directory: 下载到本地的文件夹路径。与 `filename` 二选一。
|
||||
determine_type: 是否自动根据图片类型确定拓展名,默认为 True。
|
||||
"""
|
||||
if not self.url:
|
||||
logger.warning(f'图片 `{self.uuid}` 无 url 参数,下载失败。')
|
||||
return
|
||||
|
||||
import httpx
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(self.url)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
|
||||
if filename:
|
||||
path = Path(filename)
|
||||
if determine_type:
|
||||
import imghdr
|
||||
path = path.with_suffix(
|
||||
'.' + str(imghdr.what(None, content))
|
||||
)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
elif directory:
|
||||
import imghdr
|
||||
path = Path(directory)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
path = path / f'{self.uuid}.{imghdr.what(None, content)}'
|
||||
else:
|
||||
raise ValueError("请指定文件路径或文件夹路径!")
|
||||
|
||||
import aiofiles
|
||||
async with aiofiles.open(path, 'wb') as f:
|
||||
await f.write(content)
|
||||
|
||||
return path
|
||||
|
||||
@classmethod
|
||||
async def from_local(
|
||||
cls,
|
||||
filename: typing.Union[str, Path, None] = None,
|
||||
content: typing.Optional[bytes] = None,
|
||||
) -> "Image":
|
||||
"""从本地文件路径加载图片,以 base64 的形式传递。
|
||||
|
||||
Args:
|
||||
filename: 从本地文件路径加载图片,与 `content` 二选一。
|
||||
content: 从本地文件内容加载图片,与 `filename` 二选一。
|
||||
|
||||
Returns:
|
||||
Image: 图片对象。
|
||||
"""
|
||||
if content:
|
||||
pass
|
||||
elif filename:
|
||||
path = Path(filename)
|
||||
import aiofiles
|
||||
async with aiofiles.open(path, 'rb') as f:
|
||||
content = await f.read()
|
||||
else:
|
||||
raise ValueError("请指定图片路径或图片内容!")
|
||||
import base64
|
||||
img = cls(base64=base64.b64encode(content).decode())
|
||||
return img
|
||||
|
||||
@classmethod
|
||||
def from_unsafe_path(cls, path: typing.Union[str, Path]) -> "Image":
|
||||
"""从不安全的路径加载图片。
|
||||
|
||||
Args:
|
||||
path: 从不安全的路径加载图片。
|
||||
|
||||
Returns:
|
||||
Image: 图片对象。
|
||||
"""
|
||||
return cls.construct(path=str(path))
|
||||
|
||||
|
||||
class Unknown(MessageComponent):
|
||||
"""未知。"""
|
||||
type: str = "Unknown"
|
||||
"""消息组件类型。"""
|
||||
text: str
|
||||
"""文本。"""
|
||||
|
||||
|
||||
class Voice(MessageComponent):
|
||||
"""语音。"""
|
||||
type: str = "Voice"
|
||||
"""消息组件类型。"""
|
||||
voice_id: typing.Optional[str] = None
|
||||
"""语音的 voice_id,不为空时将忽略 url 属性。"""
|
||||
url: typing.Optional[str] = None
|
||||
"""语音的 URL,发送时可作网络语音的链接;接收时为腾讯语音服务器的链接,可用于语音下载。"""
|
||||
path: typing.Optional[str] = None
|
||||
"""语音的路径,发送本地语音。"""
|
||||
base64: typing.Optional[str] = None
|
||||
"""语音的 Base64 编码。"""
|
||||
length: typing.Optional[int] = None
|
||||
"""语音的长度,单位为秒。"""
|
||||
@pydantic.validator('path')
|
||||
def validate_path(cls, path: typing.Optional[str]):
|
||||
"""修复 path 参数的行为,使之相对于 QChatGPT 的启动路径。"""
|
||||
if path:
|
||||
try:
|
||||
return str(Path(path).resolve(strict=True))
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"无效路径:{path}")
|
||||
else:
|
||||
return path
|
||||
|
||||
def __str__(self):
|
||||
return '[语音]'
|
||||
|
||||
async def download(
|
||||
self,
|
||||
filename: typing.Union[str, Path, None] = None,
|
||||
directory: typing.Union[str, Path, None] = None
|
||||
):
|
||||
"""下载语音到本地。
|
||||
|
||||
语音采用 silk v3 格式,silk 格式的编码解码请使用 [graiax-silkcoder](https://pypi.org/project/graiax-silkcoder/)。
|
||||
|
||||
Args:
|
||||
filename: 下载到本地的文件路径。与 `directory` 二选一。
|
||||
directory: 下载到本地的文件夹路径。与 `filename` 二选一。
|
||||
"""
|
||||
if not self.url:
|
||||
logger.warning(f'语音 `{self.voice_id}` 无 url 参数,下载失败。')
|
||||
return
|
||||
|
||||
import httpx
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(self.url)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
|
||||
if filename:
|
||||
path = Path(filename)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
elif directory:
|
||||
path = Path(directory)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
path = path / f'{self.voice_id}.silk'
|
||||
else:
|
||||
raise ValueError("请指定文件路径或文件夹路径!")
|
||||
|
||||
import aiofiles
|
||||
async with aiofiles.open(path, 'wb') as f:
|
||||
await f.write(content)
|
||||
|
||||
@classmethod
|
||||
async def from_local(
|
||||
cls,
|
||||
filename: typing.Union[str, Path, None] = None,
|
||||
content: typing.Optional[bytes] = None,
|
||||
) -> "Voice":
|
||||
"""从本地文件路径加载语音,以 base64 的形式传递。
|
||||
|
||||
Args:
|
||||
filename: 从本地文件路径加载语音,与 `content` 二选一。
|
||||
content: 从本地文件内容加载语音,与 `filename` 二选一。
|
||||
"""
|
||||
if content:
|
||||
pass
|
||||
if filename:
|
||||
path = Path(filename)
|
||||
import aiofiles
|
||||
async with aiofiles.open(path, 'rb') as f:
|
||||
content = await f.read()
|
||||
else:
|
||||
raise ValueError("请指定语音路径或语音内容!")
|
||||
import base64
|
||||
img = cls(base64=base64.b64encode(content).decode())
|
||||
return img
|
||||
|
||||
|
||||
class ForwardMessageNode(pydantic.BaseModel):
|
||||
"""合并转发中的一条消息。"""
|
||||
sender_id: typing.Optional[int] = None
|
||||
"""发送人QQ号。"""
|
||||
sender_name: typing.Optional[str] = None
|
||||
"""显示名称。"""
|
||||
message_chain: typing.Optional[MessageChain] = None
|
||||
"""消息内容。"""
|
||||
message_id: typing.Optional[int] = None
|
||||
"""消息的 message_id,可以只使用此属性,从缓存中读取消息内容。"""
|
||||
time: typing.Optional[datetime] = None
|
||||
"""发送时间。"""
|
||||
@pydantic.validator('message_chain', check_fields=False)
|
||||
def _validate_message_chain(cls, value: typing.Union[MessageChain, list]):
|
||||
if isinstance(value, list):
|
||||
return MessageChain.parse_obj(value)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, sender: typing.Union[platform_entities.Friend, platform_entities.GroupMember], message: MessageChain
|
||||
) -> 'ForwardMessageNode':
|
||||
"""从消息链生成转发消息。
|
||||
|
||||
Args:
|
||||
sender: 发送人。
|
||||
message: 消息内容。
|
||||
|
||||
Returns:
|
||||
ForwardMessageNode: 生成的一条消息。
|
||||
"""
|
||||
return ForwardMessageNode(
|
||||
sender_id=sender.id,
|
||||
sender_name=sender.get_name(),
|
||||
message_chain=message
|
||||
)
|
||||
|
||||
|
||||
class Forward(MessageComponent):
|
||||
"""合并转发。"""
|
||||
type: str = "Forward"
|
||||
"""消息组件类型。"""
|
||||
node_list: typing.List[ForwardMessageNode]
|
||||
"""转发消息节点列表。"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
if len(args) == 1:
|
||||
self.node_list = args[0]
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return '[聊天记录]'
|
||||
|
||||
|
||||
class File(MessageComponent):
|
||||
"""文件。"""
|
||||
type: str = "File"
|
||||
"""消息组件类型。"""
|
||||
id: str
|
||||
"""文件识别 ID。"""
|
||||
name: str
|
||||
"""文件名称。"""
|
||||
size: int
|
||||
"""文件大小。"""
|
||||
def __str__(self):
|
||||
return f'[文件]{self.name}'
|
||||
|
|
@ -3,11 +3,11 @@ from __future__ import annotations
|
|||
import typing
|
||||
import abc
|
||||
import pydantic
|
||||
import mirai
|
||||
|
||||
from . import events
|
||||
from ..provider.tools import entities as tools_entities
|
||||
from ..core import app
|
||||
from ..platform.types import message as platform_message
|
||||
|
||||
|
||||
def register(
|
||||
|
@ -174,11 +174,11 @@ class EventContext:
|
|||
self.__return_value__[key] = []
|
||||
self.__return_value__[key].append(ret)
|
||||
|
||||
async def reply(self, message_chain: mirai.MessageChain):
|
||||
async def reply(self, message_chain: platform_message.MessageChain):
|
||||
"""回复此次消息请求
|
||||
|
||||
Args:
|
||||
message_chain (mirai.MessageChain): YiriMirai库的消息链,若用户使用的不是 YiriMirai 适配器,程序也能自动转换为目标消息链
|
||||
message_chain (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链
|
||||
"""
|
||||
await self.host.ap.platform_mgr.send(
|
||||
event=self.event.query.message_event,
|
||||
|
@ -190,14 +190,14 @@ class EventContext:
|
|||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: mirai.MessageChain
|
||||
message: platform_message.MessageChain
|
||||
):
|
||||
"""主动发送消息
|
||||
|
||||
Args:
|
||||
target_type (str): 目标类型,`person`或`group`
|
||||
target_id (str): 目标ID
|
||||
message (mirai.MessageChain): YiriMirai库的消息链,若用户使用的不是 YiriMirai 适配器,程序也能自动转换为目标消息链
|
||||
message (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链
|
||||
"""
|
||||
await self.event.query.adapter.send_message(
|
||||
target_type=target_type,
|
||||
|
|
|
@ -3,10 +3,10 @@ from __future__ import annotations
|
|||
import typing
|
||||
|
||||
import pydantic
|
||||
import mirai
|
||||
|
||||
from ..core import entities as core_entities
|
||||
from ..provider import entities as llm_entities
|
||||
from ..platform.types import message as platform_message
|
||||
|
||||
|
||||
class BaseEventModel(pydantic.BaseModel):
|
||||
|
@ -31,7 +31,7 @@ class PersonMessageReceived(BaseEventModel):
|
|||
sender_id: int
|
||||
"""发送者ID(QQ号)"""
|
||||
|
||||
message_chain: mirai.MessageChain
|
||||
message_chain: platform_message.MessageChain
|
||||
|
||||
|
||||
class GroupMessageReceived(BaseEventModel):
|
||||
|
@ -43,7 +43,7 @@ class GroupMessageReceived(BaseEventModel):
|
|||
|
||||
sender_id: int
|
||||
|
||||
message_chain: mirai.MessageChain
|
||||
message_chain: platform_message.MessageChain
|
||||
|
||||
|
||||
class PersonNormalMessageReceived(BaseEventModel):
|
||||
|
|
|
@ -48,6 +48,8 @@ class PluginManager:
|
|||
# 按优先级倒序
|
||||
self.plugins.sort(key=lambda x: x.priority, reverse=True)
|
||||
|
||||
self.ap.logger.debug(f'优先级排序后的插件列表 {self.plugins}')
|
||||
|
||||
async def initialize_plugins(self):
|
||||
for plugin in self.plugins:
|
||||
try:
|
||||
|
|
|
@ -45,6 +45,7 @@ class SettingManager:
|
|||
for plugin_container in plugin_containers:
|
||||
if plugin_container.plugin_name == value['name']:
|
||||
plugin_container.set_from_setting_dict(value)
|
||||
break
|
||||
|
||||
self.settings.data = {
|
||||
'plugins': [
|
||||
|
|
|
@ -4,7 +4,8 @@ import typing
|
|||
import enum
|
||||
import pydantic
|
||||
|
||||
import mirai
|
||||
|
||||
from ..platform.types import message as platform_message
|
||||
|
||||
|
||||
class FunctionCall(pydantic.BaseModel):
|
||||
|
@ -73,14 +74,14 @@ class Message(pydantic.BaseModel):
|
|||
|
||||
def readable_str(self) -> str:
|
||||
if self.content is not None:
|
||||
return str(self.role) + ": " + str(self.get_content_mirai_message_chain())
|
||||
return str(self.role) + ": " + str(self.get_content_platform_message_chain())
|
||||
elif self.tool_calls is not None:
|
||||
return f'调用工具: {self.tool_calls[0].id}'
|
||||
else:
|
||||
return '未知消息'
|
||||
|
||||
def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None:
|
||||
"""将内容转换为 Mirai MessageChain 对象
|
||||
def get_content_platform_message_chain(self, prefix_text: str="") -> platform_message.MessageChain | None:
|
||||
"""将内容转换为平台消息 MessageChain 对象
|
||||
|
||||
Args:
|
||||
prefix_text (str): 首个文字组件的前缀文本
|
||||
|
@ -89,15 +90,15 @@ class Message(pydantic.BaseModel):
|
|||
if self.content is None:
|
||||
return None
|
||||
elif isinstance(self.content, str):
|
||||
return mirai.MessageChain([mirai.Plain(prefix_text+self.content)])
|
||||
return platform_message.MessageChain([platform_message.Plain(prefix_text+self.content)])
|
||||
elif isinstance(self.content, list):
|
||||
mc = []
|
||||
for ce in self.content:
|
||||
if ce.type == 'text':
|
||||
mc.append(mirai.Plain(ce.text))
|
||||
mc.append(platform_message.Plain(ce.text))
|
||||
elif ce.type == 'image_url':
|
||||
if ce.image_url.url.startswith("http"):
|
||||
mc.append(mirai.Image(url=ce.image_url.url))
|
||||
mc.append(platform_message.Image(url=ce.image_url.url))
|
||||
else: # base64
|
||||
|
||||
b64_str = ce.image_url.url
|
||||
|
@ -105,15 +106,15 @@ class Message(pydantic.BaseModel):
|
|||
if b64_str.startswith("data:"):
|
||||
b64_str = b64_str.split(",")[1]
|
||||
|
||||
mc.append(mirai.Image(base64=b64_str))
|
||||
|
||||
mc.append(platform_message.Image(base64=b64_str))
|
||||
|
||||
# 找第一个文字组件
|
||||
if prefix_text:
|
||||
for i, c in enumerate(mc):
|
||||
if isinstance(c, mirai.Plain):
|
||||
mc[i] = mirai.Plain(prefix_text+c.text)
|
||||
if isinstance(c, platform_message.Plain):
|
||||
mc[i] = platform_message.Plain(prefix_text+c.text)
|
||||
break
|
||||
else:
|
||||
mc.insert(0, mirai.Plain(prefix_text))
|
||||
mc.insert(0, platform_message.Plain(prefix_text))
|
||||
|
||||
return mirai.MessageChain(mc)
|
||||
return platform_message.MessageChain(mc)
|
||||
|
|
|
@ -2,7 +2,6 @@ requests
|
|||
openai>1.0.0
|
||||
anthropic
|
||||
colorlog~=6.6.0
|
||||
yiri-mirai-rc
|
||||
aiocqhttp
|
||||
qq-botpy
|
||||
nakuru-project-idk
|
||||
|
|
Loading…
Reference in New Issue
Block a user