refactor: extract preprocessing into a standalone stage

RockChinQ 2024-02-01 16:35:00 +08:00
parent 976a9de39c
commit 532a713355
9 changed files with 156 additions and 90 deletions

View File

@@ -45,6 +45,22 @@ class Query(pydantic.BaseModel):
"""Message chain received from the platform"""
session: typing.Optional[Session] = None
"""Session object, set by the preprocessor"""
messages: typing.Optional[list[llm_entities.Message]] = []
"""History message list, set by the preprocessor"""
prompt: typing.Optional[sysprompt_entities.Prompt] = None
"""Prompt (scenario preset) content, set by the preprocessor"""
user_message: typing.Optional[llm_entities.Message] = None
"""User message object of this request, set by the preprocessor"""
use_model: typing.Optional[entities.LLMModelInfo] = None
"""Model to use, set by the preprocessor"""
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""Functions to use, set by the preprocessor"""
resp_messages: typing.Optional[list[llm_entities.Message]] = []
"""List of reply message objects generated by the provider"""

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...provider import entities as llm_entities
from ...plugin import events
@stage.stage_class("PreProcessor")
class PreProcessor(stage.PipelineStage):
"""Preprocessor
"""
async def process(
self,
query: core_entities.Query,
stage_inst_name: str,
) -> entities.StageProcessResult:
"""处理
"""
session = await self.ap.sess_mgr.get_session(query)
conversation = await self.ap.sess_mgr.get_conversation(session)
# Pull messages and the prompt preset from the session into the query
query.session = session
query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy()
query.user_message = llm_entities.Message(
role='user',
content=str(query.message_chain).strip()
)
query.use_model = conversation.use_model
query.use_funcs = conversation.use_funcs
# =========== Emit event: PromptPreProcessing
session = query.session
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PromptPreProcessing(
session_name=f'{session.launcher_type.value}_{session.launcher_id}',
default_prompt=query.prompt.messages,
prompt=query.messages,
query=query
)
)
query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt
# Trim according to the model's max_tokens
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
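
The @stage.stage_class("PreProcessor") decorator evidently registers the class under a string name so that stage_order (below) can refer to stages by name. A hedged sketch of that registration pattern, assuming a plain dict registry (the actual stage.stage_class implementation is not shown in this diff):

# Sketch of a name-to-class stage registry; the real implementation may differ.
_stage_registry: dict[str, type] = {}


def stage_class(name: str):
    """Register a pipeline stage class under a string name."""
    def decorator(cls: type) -> type:
        _stage_registry[name] = cls
        return cls
    return decorator


@stage_class('PreProcessor')
class PreProcessor:
    pass


assert _stage_registry['PreProcessor'] is PreProcessor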

View File

@@ -58,38 +58,24 @@ class ChatMessageHandler(handler.MessageHandler):
mirai.Plain(event_ctx.event.alter)
])
session = await self.ap.sess_mgr.get_session(query)
conversation = await self.ap.sess_mgr.get_conversation(session)
# =========== Emit event: PromptPreProcessing
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PromptPreProcessing(
session_name=f'{session.launcher_type.value}_{session.launcher_id}',
default_prompt=conversation.prompt.messages,
prompt=conversation.messages,
query=query
)
query.messages.append(
query.user_message
)
conversation.prompt.messages = event_ctx.event.default_prompt
conversation.messages = event_ctx.event.prompt
conversation.messages.append(
llm_entities.Message(
role="user",
content=str(query.message_chain)
)
query.session.using_conversation.messages.append(
query.user_message
)
text_length = 0
start_time = time.time()
async for result in conversation.use_model.requester.request(query, conversation):
async for result in query.use_model.requester.request(query):
query.resp_messages.append(result)
# Sync the message into the session
query.session.using_conversation.messages.append(result)
if result.content is not None:
text_length += len(result.content)
@@ -99,11 +85,11 @@ class ChatMessageHandler(handler.MessageHandler):
)
await self.ap.ctr_mgr.usage.post_query_record(
session_type=session.launcher_type.value,
session_id=str(session.launcher_id),
session_type=query.session.launcher_type.value,
session_id=str(query.session.launcher_id),
query_ability_provider="QChatGPT.Chat",
usage=text_length,
model_name=conversation.use_model.name,
model_name=query.use_model.name,
response_seconds=int(time.time() - start_time),
retry_times=-1,
)

View File

@@ -11,12 +11,14 @@ from .process import process
from .longtext import longtext
from .respback import respback
from .wrapper import wrapper
from .preproc import preproc
stage_order = [
"GroupRespondRuleCheckStage",
"BanSessionCheckStage",
"PreContentFilterStage",
"PreProcessor",
"MessageProcessor",
"PostContentFilterStage",
"ResponseWrapper",

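The updated stage_order makes the new stage's position explicit: PreProcessor runs after the pre-content filter and before MessageProcessor, so the handler can rely on query.session, query.use_model, and the other preprocessed fields being set. A simplified sketch of how such an ordered pipeline might be driven, assuming each stage returns a StageProcessResult whose result_type decides whether to continue (not the project's actual stagemgr):

import enum


class ResultType(enum.Enum):
    CONTINUE = enum.auto()
    INTERRUPT = enum.auto()


async def run_pipeline(stages: list, query):
    """Run stages in order; an INTERRUPT result stops the pipeline early."""
    for stage_inst in stages:
        result = await stage_inst.process(query, type(stage_inst).__name__)
        if result.result_type is ResultType.INTERRUPT:
            break
        query = result.new_query
    return query
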
View File

@@ -33,58 +33,18 @@ class ResponseWrapper(stage.PipelineStage):
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif query.resp_messages[-1].role == 'assistant':
result = query.resp_messages[-1]
session = await self.ap.sess_mgr.get_session(query)
else:
reply_text = ''
if query.resp_messages[-1].role == 'assistant':
result = query.resp_messages[-1]
session = await self.ap.sess_mgr.get_session(query)
if result.content is not None:  # has content
reply_text = result.content
reply_text = ''
# ============= Emit plugin event ===============
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
if result.content is not None:  # has content
reply_text = result.content
else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
if result.tool_calls is not None:  # has function calls
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'Calling functions {".".join(function_names)}...'
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
if self.ap.cfg_mgr.data['trace_function_calls']:
# ============= Emit plugin event ===============
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
@@ -98,7 +58,6 @@ class ResponseWrapper(stage.PipelineStage):
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
@@ -116,4 +75,47 @@ class ResponseWrapper(stage.PipelineStage):
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
)
if result.tool_calls is not None:  # has function calls
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'Calling functions {".".join(function_names)}...'
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
if self.ap.cfg_mgr.data['trace_function_calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@@ -23,7 +23,6 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
async def request(
self,
query: core_entities.Query,
conversation: core_entities.Conversation,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求
"""

View File

@@ -7,9 +7,10 @@ import json
import openai
import openai.types.chat.chat_completion as chat_completion
from .. import api
from .. import api, entities
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
class OpenAIChatCompletion(api.LLMAPIRequester):
@@ -42,15 +43,16 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
async def _closure(
self,
req_messages: list[dict],
conversation: core_entities.Conversation,
use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
self.client.api_key = conversation.use_model.token_mgr.get_token()
self.client.api_key = use_model.token_mgr.get_token()
args = self.ap.cfg_mgr.data["completion_api_params"].copy()
args["model"] = conversation.use_model.name if conversation.use_model.model_name is None else conversation.use_model.model_name
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
if conversation.use_model.tool_call_supported:
tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation)
if use_model.tool_call_supported:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools:
args["tools"] = tools
@@ -68,19 +70,19 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
return message
async def request(
self, query: core_entities.Query, conversation: core_entities.Conversation
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求"""
pending_tool_calls = []
req_messages = [
m.dict(exclude_none=True) for m in conversation.prompt.messages
] + [m.dict(exclude_none=True) for m in conversation.messages]
req_messages = [  # req_messages is used only inside this class; external sync goes through query.messages
m.dict(exclude_none=True) for m in query.prompt.messages
] + [m.dict(exclude_none=True) for m in query.messages]
# req_messages.append({"role": "user", "content": str(query.message_chain)})
msg = await self._closure(req_messages, conversation)
msg = await self._closure(req_messages, query.use_model, query.use_funcs)
yield msg
@@ -107,7 +109,7 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
req_messages.append(msg.dict(exclude_none=True))
# All tool calls handled; continue the request
msg = await self._closure(req_messages, conversation)
msg = await self._closure(req_messages, query.use_model, query.use_funcs)
yield msg
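
The loop around _closure (only partially visible in this hunk) keeps re-requesting until the model stops emitting tool calls. In outline, and hedged (the _execute_tool helper below is hypothetical; OpenAI's protocol expects the assistant turn carrying tool_calls to be appended before the tool results):

msg = await self._closure(req_messages, query.use_model, query.use_funcs)
yield msg

while msg.tool_calls:
    req_messages.append(msg.dict(exclude_none=True))  # assistant turn with tool_calls
    for tool_call in msg.tool_calls:
        # Run the tool and append its role='tool' result message.
        req_messages.append(await self._execute_tool(tool_call))  # hypothetical helper
    msg = await self._closure(req_messages, query.use_model, query.use_funcs)
    yield msg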

View File

@@ -38,12 +38,12 @@ class ToolManager:
return all_functions
async def generate_tools_for_openai(self, conversation: core_entities.Conversation) -> str:
async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> str:
"""Generate the function list
"""
tools = []
for function in conversation.use_funcs:
for function in use_funcs:
if function.enable:
function_schema = {
"type": "function",