mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
refactor: 独立出预处理阶段
This commit is contained in:
parent
976a9de39c
commit
532a713355
|
@ -45,6 +45,22 @@ class Query(pydantic.BaseModel):
|
|||
"""消息链,platform收到的消息链"""
|
||||
|
||||
session: typing.Optional[Session] = None
|
||||
"""会话对象,由前置处理器设置"""
|
||||
|
||||
messages: typing.Optional[list[llm_entities.Message]] = []
|
||||
"""历史消息列表,由前置处理器设置"""
|
||||
|
||||
prompt: typing.Optional[sysprompt_entities.Prompt] = None
|
||||
"""情景预设内容,由前置处理器设置"""
|
||||
|
||||
user_message: typing.Optional[llm_entities.Message] = None
|
||||
"""此次请求的用户消息对象,由前置处理器设置"""
|
||||
|
||||
use_model: typing.Optional[entities.LLMModelInfo] = None
|
||||
"""使用的模型,由前置处理器设置"""
|
||||
|
||||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
|
||||
"""使用的函数,由前置处理器设置"""
|
||||
|
||||
resp_messages: typing.Optional[list[llm_entities.Message]] = []
|
||||
"""由provider生成的回复消息对象列表"""
|
||||
|
|
0
pkg/pipeline/preproc/__init__.py
Normal file
0
pkg/pipeline/preproc/__init__.py
Normal file
59
pkg/pipeline/preproc/preproc.py
Normal file
59
pkg/pipeline/preproc/preproc.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...provider import entities as llm_entities
|
||||
from ...plugin import events
|
||||
|
||||
|
||||
@stage.stage_class("PreProcessor")
|
||||
class PreProcessor(stage.PipelineStage):
|
||||
"""预处理器
|
||||
"""
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str,
|
||||
) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
conversation = await self.ap.sess_mgr.get_conversation(session)
|
||||
|
||||
# 从会话取出消息和情景预设到query
|
||||
query.session = session
|
||||
query.prompt = conversation.prompt.copy()
|
||||
query.messages = conversation.messages.copy()
|
||||
|
||||
query.user_message = llm_entities.Message(
|
||||
role='user',
|
||||
content=str(query.message_chain).strip()
|
||||
)
|
||||
|
||||
query.use_model = conversation.use_model
|
||||
|
||||
query.use_funcs = conversation.use_funcs
|
||||
|
||||
# =========== 触发事件 PromptPreProcessing
|
||||
session = query.session
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.PromptPreProcessing(
|
||||
session_name=f'{session.launcher_type.value}_{session.launcher_id}',
|
||||
default_prompt=query.prompt.messages,
|
||||
prompt=query.messages,
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
query.prompt.messages = event_ctx.event.default_prompt
|
||||
query.messages = event_ctx.event.prompt
|
||||
|
||||
# 根据模型max_tokens剪裁
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
|
@ -58,38 +58,24 @@ class ChatMessageHandler(handler.MessageHandler):
|
|||
mirai.Plain(event_ctx.event.alter)
|
||||
])
|
||||
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
conversation = await self.ap.sess_mgr.get_conversation(session)
|
||||
|
||||
# =========== 触发事件 PromptPreProcessing
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.PromptPreProcessing(
|
||||
session_name=f'{session.launcher_type.value}_{session.launcher_id}',
|
||||
default_prompt=conversation.prompt.messages,
|
||||
prompt=conversation.messages,
|
||||
query=query
|
||||
)
|
||||
query.messages.append(
|
||||
query.user_message
|
||||
)
|
||||
|
||||
conversation.prompt.messages = event_ctx.event.default_prompt
|
||||
conversation.messages = event_ctx.event.prompt
|
||||
|
||||
conversation.messages.append(
|
||||
llm_entities.Message(
|
||||
role="user",
|
||||
content=str(query.message_chain)
|
||||
)
|
||||
query.session.using_conversation.messages.append(
|
||||
query.user_message
|
||||
)
|
||||
|
||||
text_length = 0
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
async for result in conversation.use_model.requester.request(query, conversation):
|
||||
async for result in query.use_model.requester.request(query):
|
||||
query.resp_messages.append(result)
|
||||
|
||||
# 消息同步到会话
|
||||
query.session.using_conversation.messages.append(result)
|
||||
|
||||
if result.content is not None:
|
||||
text_length += len(result.content)
|
||||
|
||||
|
@ -99,11 +85,11 @@ class ChatMessageHandler(handler.MessageHandler):
|
|||
)
|
||||
|
||||
await self.ap.ctr_mgr.usage.post_query_record(
|
||||
session_type=session.launcher_type.value,
|
||||
session_id=str(session.launcher_id),
|
||||
session_type=query.session.launcher_type.value,
|
||||
session_id=str(query.session.launcher_id),
|
||||
query_ability_provider="QChatGPT.Chat",
|
||||
usage=text_length,
|
||||
model_name=conversation.use_model.name,
|
||||
model_name=query.use_model.name,
|
||||
response_seconds=int(time.time() - start_time),
|
||||
retry_times=-1,
|
||||
)
|
|
@ -11,12 +11,14 @@ from .process import process
|
|||
from .longtext import longtext
|
||||
from .respback import respback
|
||||
from .wrapper import wrapper
|
||||
from .preproc import preproc
|
||||
|
||||
|
||||
stage_order = [
|
||||
"GroupRespondRuleCheckStage",
|
||||
"BanSessionCheckStage",
|
||||
"PreContentFilterStage",
|
||||
"PreProcessor",
|
||||
"MessageProcessor",
|
||||
"PostContentFilterStage",
|
||||
"ResponseWrapper",
|
||||
|
|
|
@ -33,58 +33,18 @@ class ResponseWrapper(stage.PipelineStage):
|
|||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
elif query.resp_messages[-1].role == 'assistant':
|
||||
result = query.resp_messages[-1]
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
else:
|
||||
|
||||
reply_text = ''
|
||||
if query.resp_messages[-1].role == 'assistant':
|
||||
result = query.resp_messages[-1]
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
if result.content is not None: # 有内容
|
||||
reply_text = result.content
|
||||
reply_text = ''
|
||||
|
||||
# ============= 触发插件事件 ===============
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
session=session,
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
|
||||
query=query
|
||||
)
|
||||
)
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
if result.content is not None: # 有内容
|
||||
reply_text = result.content
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
if result.tool_calls is not None: # 有函数调用
|
||||
|
||||
function_names = [tc.function.name for tc in result.tool_calls]
|
||||
|
||||
reply_text = f'调用函数 {".".join(function_names)}...'
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
if self.ap.cfg_mgr.data['trace_function_calls']:
|
||||
|
||||
# ============= 触发插件事件 ===============
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
|
@ -98,7 +58,6 @@ class ResponseWrapper(stage.PipelineStage):
|
|||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
|
@ -116,4 +75,47 @@ class ResponseWrapper(stage.PipelineStage):
|
|||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
)
|
||||
|
||||
if result.tool_calls is not None: # 有函数调用
|
||||
|
||||
function_names = [tc.function.name for tc in result.tool_calls]
|
||||
|
||||
reply_text = f'调用函数 {".".join(function_names)}...'
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
if self.ap.cfg_mgr.data['trace_function_calls']:
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
session=session,
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
|
@ -23,7 +23,6 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
|
|||
async def request(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
conversation: core_entities.Conversation,
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""请求
|
||||
"""
|
||||
|
|
|
@ -7,9 +7,10 @@ import json
|
|||
import openai
|
||||
import openai.types.chat.chat_completion as chat_completion
|
||||
|
||||
from .. import api
|
||||
from .. import api, entities
|
||||
from ....core import entities as core_entities
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
|
||||
|
||||
class OpenAIChatCompletion(api.LLMAPIRequester):
|
||||
|
@ -42,15 +43,16 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
|
|||
async def _closure(
|
||||
self,
|
||||
req_messages: list[dict],
|
||||
conversation: core_entities.Conversation,
|
||||
use_model: entities.LLMModelInfo,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = conversation.use_model.token_mgr.get_token()
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.ap.cfg_mgr.data["completion_api_params"].copy()
|
||||
args["model"] = conversation.use_model.name if conversation.use_model.model_name is None else conversation.use_model.model_name
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
|
||||
if conversation.use_model.tool_call_supported:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation)
|
||||
if use_model.tool_call_supported:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
|
@ -68,19 +70,19 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
|
|||
return message
|
||||
|
||||
async def request(
|
||||
self, query: core_entities.Query, conversation: core_entities.Conversation
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""请求"""
|
||||
|
||||
pending_tool_calls = []
|
||||
|
||||
req_messages = [
|
||||
m.dict(exclude_none=True) for m in conversation.prompt.messages
|
||||
] + [m.dict(exclude_none=True) for m in conversation.messages]
|
||||
req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行
|
||||
m.dict(exclude_none=True) for m in query.prompt.messages
|
||||
] + [m.dict(exclude_none=True) for m in query.messages]
|
||||
|
||||
# req_messages.append({"role": "user", "content": str(query.message_chain)})
|
||||
|
||||
msg = await self._closure(req_messages, conversation)
|
||||
msg = await self._closure(req_messages, query.use_model, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
|
@ -107,7 +109,7 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
|
|||
req_messages.append(msg.dict(exclude_none=True))
|
||||
|
||||
# 处理完所有调用,继续请求
|
||||
msg = await self._closure(req_messages, conversation)
|
||||
msg = await self._closure(req_messages, query.use_model, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
|
|
|
@ -38,12 +38,12 @@ class ToolManager:
|
|||
|
||||
return all_functions
|
||||
|
||||
async def generate_tools_for_openai(self, conversation: core_entities.Conversation) -> str:
|
||||
async def generate_tools_for_openai(self, use_funcs: entities.LLMFunction) -> str:
|
||||
"""生成函数列表
|
||||
"""
|
||||
tools = []
|
||||
|
||||
for function in conversation.use_funcs:
|
||||
for function in use_funcs:
|
||||
if function.enable:
|
||||
function_schema = {
|
||||
"type": "function",
|
||||
|
|
Loading…
Reference in New Issue
Block a user