From 532a713355c37c44444881939ef23450c2931456 Mon Sep 17 00:00:00 2001
From: RockChinQ <1010553892@qq.com>
Date: Thu, 1 Feb 2024 16:35:00 +0800
Subject: [PATCH] refactor: split pre-processing into its own pipeline stage
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 pkg/core/entities.py                    |  16 ++++
 pkg/pipeline/preproc/__init__.py        |   0
 pkg/pipeline/preproc/preproc.py         |  59 ++++++++++++++
 pkg/pipeline/process/handlers/chat.py   |  36 +++------
 pkg/pipeline/stagemgr.py                |   2 +
 pkg/pipeline/wrapper/wrapper.py         | 102 ++++++++++++------------
 pkg/provider/requester/api.py           |   1 -
 pkg/provider/requester/apis/chatcmpl.py |  26 +++---
 pkg/provider/tools/toolmgr.py           |   4 +-
 9 files changed, 156 insertions(+), 90 deletions(-)
 create mode 100644 pkg/pipeline/preproc/__init__.py
 create mode 100644 pkg/pipeline/preproc/preproc.py

diff --git a/pkg/core/entities.py b/pkg/core/entities.py
index 8e25750..53d515e 100644
--- a/pkg/core/entities.py
+++ b/pkg/core/entities.py
@@ -45,6 +45,22 @@ class Query(pydantic.BaseModel):
     """Message chain as received from the platform"""
 
     session: typing.Optional[Session] = None
+    """Session object, set by the pre-processor"""
+
+    messages: typing.Optional[list[llm_entities.Message]] = []
+    """History message list, set by the pre-processor"""
+
+    prompt: typing.Optional[sysprompt_entities.Prompt] = None
+    """Prompt (persona preset) content, set by the pre-processor"""
+
+    user_message: typing.Optional[llm_entities.Message] = None
+    """User message object of this request, set by the pre-processor"""
+
+    use_model: typing.Optional[entities.LLMModelInfo] = None
+    """Model to use, set by the pre-processor"""
+
+    use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
+    """Functions to use, set by the pre-processor"""
 
     resp_messages: typing.Optional[list[llm_entities.Message]] = []
     """Reply message objects generated by the provider"""
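
With the hunk above, Query becomes the single carrier of per-request state: any stage that runs after the pre-processor can read the session, history, prompt, model and function list straight off the query instead of re-fetching them. A minimal sketch of such a consumer, reusing the stage decorator and result types that preproc.py below is built on (the ModelInfoDumper stage is hypothetical, not part of this patch, and would also have to be added to stage_order in stagemgr.py to run):

    from .. import stage, entities
    from ...core import entities as core_entities

    @stage.stage_class("ModelInfoDumper")
    class ModelInfoDumper(stage.PipelineStage):
        """Hypothetical stage that only reads what PreProcessor filled in."""

        async def process(
            self,
            query: core_entities.Query,
            stage_inst_name: str,
        ) -> entities.StageProcessResult:
            # These fields are None/[] until PreProcessor has run,
            # so this stage must be ordered after "PreProcessor".
            print(f'model={query.use_model.name} '
                  f'funcs={len(query.use_funcs or [])} '
                  f'history={len(query.messages)}')
            return entities.StageProcessResult(
                result_type=entities.ResultType.CONTINUE,
                new_query=query
            )
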
diff --git a/pkg/pipeline/preproc/__init__.py b/pkg/pipeline/preproc/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py
new file mode 100644
index 0000000..578c743
--- /dev/null
+++ b/pkg/pipeline/preproc/preproc.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+from .. import stage, entities, stagemgr
+from ...core import entities as core_entities
+from ...provider import entities as llm_entities
+from ...plugin import events
+
+
+@stage.stage_class("PreProcessor")
+class PreProcessor(stage.PipelineStage):
+    """Pre-processor
+    """
+
+    async def process(
+        self,
+        query: core_entities.Query,
+        stage_inst_name: str,
+    ) -> entities.StageProcessResult:
+        """Process
+        """
+        session = await self.ap.sess_mgr.get_session(query)
+
+        conversation = await self.ap.sess_mgr.get_conversation(session)
+
+        # Pull the messages and prompt preset from the session onto the query
+        query.session = session
+        query.prompt = conversation.prompt.copy()
+        query.messages = conversation.messages.copy()
+
+        query.user_message = llm_entities.Message(
+            role='user',
+            content=str(query.message_chain).strip()
+        )
+
+        query.use_model = conversation.use_model
+
+        query.use_funcs = conversation.use_funcs
+
+        # =========== emit event: PromptPreProcessing
+        session = query.session
+
+        event_ctx = await self.ap.plugin_mgr.emit_event(
+            event=events.PromptPreProcessing(
+                session_name=f'{session.launcher_type.value}_{session.launcher_id}',
+                default_prompt=query.prompt.messages,
+                prompt=query.messages,
+                query=query
+            )
+        )
+
+        query.prompt.messages = event_ctx.event.default_prompt
+        query.messages = event_ctx.event.prompt
+
+        # Trim according to the model's max_tokens
+
+        return entities.StageProcessResult(
+            result_type=entities.ResultType.CONTINUE,
+            new_query=query
+        )
\ No newline at end of file
diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py
index 603d1f1..bcacf93 100644
--- a/pkg/pipeline/process/handlers/chat.py
+++ b/pkg/pipeline/process/handlers/chat.py
@@ -58,38 +58,24 @@ class ChatMessageHandler(handler.MessageHandler):
                     mirai.Plain(event_ctx.event.alter)
                 ])
 
-            session = await self.ap.sess_mgr.get_session(query)
-
-            conversation = await self.ap.sess_mgr.get_conversation(session)
-
-            # =========== emit event: PromptPreProcessing
-
-            event_ctx = await self.ap.plugin_mgr.emit_event(
-                event=events.PromptPreProcessing(
-                    session_name=f'{session.launcher_type.value}_{session.launcher_id}',
-                    default_prompt=conversation.prompt.messages,
-                    prompt=conversation.messages,
-                    query=query
-                )
+            query.messages.append(
+                query.user_message
             )
 
-            conversation.prompt.messages = event_ctx.event.default_prompt
-            conversation.messages = event_ctx.event.prompt
-
-            conversation.messages.append(
-                llm_entities.Message(
-                    role="user",
-                    content=str(query.message_chain)
-                )
+            query.session.using_conversation.messages.append(
+                query.user_message
            )
 
             text_length = 0
 
             start_time = time.time()
 
-            async for result in conversation.use_model.requester.request(query, conversation):
+            async for result in query.use_model.requester.request(query):
                 query.resp_messages.append(result)
 
+                # Sync the message to the session
+                query.session.using_conversation.messages.append(result)
+
                 if result.content is not None:
                     text_length += len(result.content)
 
@@ -99,11 +85,11 @@ class ChatMessageHandler(handler.MessageHandler):
             )
 
             await self.ap.ctr_mgr.usage.post_query_record(
-                session_type=session.launcher_type.value,
-                session_id=str(session.launcher_id),
+                session_type=query.session.launcher_type.value,
+                session_id=str(query.session.launcher_id),
                 query_ability_provider="QChatGPT.Chat",
                 usage=text_length,
-                model_name=conversation.use_model.name,
+                model_name=query.use_model.name,
                 response_seconds=int(time.time() - start_time),
                 retry_times=-1,
             )
\ No newline at end of file
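
The handler change above is the other half of the refactor: chat.py no longer builds the prompt or emits PromptPreProcessing; it only appends query.user_message and mirrors every message the requester yields into query.session.using_conversation.messages, so a request that produces several messages (e.g. tool-call rounds) keeps the session history complete. The loop shape, reduced to a runnable sketch with stand-in objects (SimpleNamespace replaces the real Message/Session/Conversation models, and fake_request stands in for use_model.requester.request):

    import asyncio
    from types import SimpleNamespace

    async def fake_request(query):  # stands in for query.use_model.requester.request(query)
        for text in ('calling a tool...', 'final answer'):
            yield SimpleNamespace(role='assistant', content=text)

    async def handle(query):
        query.messages.append(query.user_message)
        query.session.using_conversation.messages.append(query.user_message)
        async for result in fake_request(query):
            query.resp_messages.append(result)
            # mirror every reply into the session history
            query.session.using_conversation.messages.append(result)

    query = SimpleNamespace(
        user_message=SimpleNamespace(role='user', content='hi'),
        messages=[], resp_messages=[],
        session=SimpleNamespace(using_conversation=SimpleNamespace(messages=[])),
    )
    asyncio.run(handle(query))
    assert len(query.session.using_conversation.messages) == 3  # user msg + 2 replies
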
diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py
index 24cb20f..b3faaf9 100644
--- a/pkg/pipeline/stagemgr.py
+++ b/pkg/pipeline/stagemgr.py
@@ -11,12 +11,14 @@ from .process import process
 from .longtext import longtext
 from .respback import respback
 from .wrapper import wrapper
+from .preproc import preproc
 
 
 stage_order = [
     "GroupRespondRuleCheckStage",
     "BanSessionCheckStage",
     "PreContentFilterStage",
+    "PreProcessor",
     "MessageProcessor",
     "PostContentFilterStage",
     "ResponseWrapper",
diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py
index e333f00..0625a3d 100644
--- a/pkg/pipeline/wrapper/wrapper.py
+++ b/pkg/pipeline/wrapper/wrapper.py
@@ -33,58 +33,18 @@ class ResponseWrapper(stage.PipelineStage):
                 result_type=entities.ResultType.CONTINUE,
                 new_query=query
             )
-        elif query.resp_messages[-1].role == 'assistant':
-            result = query.resp_messages[-1]
-            session = await self.ap.sess_mgr.get_session(query)
+        else:
 
-            reply_text = ''
+            if query.resp_messages[-1].role == 'assistant':
+                result = query.resp_messages[-1]
+                session = await self.ap.sess_mgr.get_session(query)
 
-            if result.content is not None:  # has content
-                reply_text = result.content
+                reply_text = ''
 
-            # ============= emit plugin event ===============
-            event_ctx = await self.ap.plugin_mgr.emit_event(
-                event=events.NormalMessageResponded(
-                    launcher_type=query.launcher_type.value,
-                    launcher_id=query.launcher_id,
-                    sender_id=query.sender_id,
-                    session=session,
-                    prefix='',
-                    response_text=reply_text,
-                    finish_reason='stop',
-                    funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
-                    query=query
-                )
-            )
-            if event_ctx.is_prevented_default():
-                yield entities.StageProcessResult(
-                    result_type=entities.ResultType.INTERRUPT,
-                    new_query=query
-                )
-            else:
-                if event_ctx.event.reply is not None:
-
-                    query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
+                if result.content is not None:  # has content
+                    reply_text = result.content
 
-                else:
-
-                    query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
-
-                yield entities.StageProcessResult(
-                    result_type=entities.ResultType.CONTINUE,
-                    new_query=query
-                )
-
-            if result.tool_calls is not None:  # has function calls
-
-                function_names = [tc.function.name for tc in result.tool_calls]
-
-                reply_text = f'调用函数 {".".join(function_names)}...'
-
-                query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
-
-                if self.ap.cfg_mgr.data['trace_function_calls']:
-
+                # ============= emit plugin event ===============
                 event_ctx = await self.ap.plugin_mgr.emit_event(
                     event=events.NormalMessageResponded(
                         launcher_type=query.launcher_type.value,
@@ -98,7 +58,6 @@ class ResponseWrapper(stage.PipelineStage):
                         query=query
                     )
                 )
-
                 if event_ctx.is_prevented_default():
                     yield entities.StageProcessResult(
                         result_type=entities.ResultType.INTERRUPT,
@@ -116,4 +75,47 @@ class ResponseWrapper(stage.PipelineStage):
                 yield entities.StageProcessResult(
                     result_type=entities.ResultType.CONTINUE,
                     new_query=query
-                )
\ No newline at end of file
+                )
+
+            if result.tool_calls is not None:  # has function calls
+
+                function_names = [tc.function.name for tc in result.tool_calls]
+
+                reply_text = f'调用函数 {".".join(function_names)}...'
+
+                query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
+
+                if self.ap.cfg_mgr.data['trace_function_calls']:
+
+                    event_ctx = await self.ap.plugin_mgr.emit_event(
+                        event=events.NormalMessageResponded(
+                            launcher_type=query.launcher_type.value,
+                            launcher_id=query.launcher_id,
+                            sender_id=query.sender_id,
+                            session=session,
+                            prefix='',
+                            response_text=reply_text,
+                            finish_reason='stop',
+                            funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
+                            query=query
+                        )
+                    )
+
+                    if event_ctx.is_prevented_default():
+                        yield entities.StageProcessResult(
+                            result_type=entities.ResultType.INTERRUPT,
+                            new_query=query
+                        )
+                    else:
+                        if event_ctx.event.reply is not None:
+
+                            query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
+
+                        else:
+
+                            query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
+
+                yield entities.StageProcessResult(
+                    result_type=entities.ResultType.CONTINUE,
+                    new_query=query
+                )
\ No newline at end of file
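
The wrapper rewrite above is mostly an indentation move: the plain-reply branch and the tool-call branch now both live under one else:, and because ResponseWrapper is an async-generator stage, a single assistant message can yield two results (a text reply plus a "calling function ..." notice). The decision tree after NormalMessageResponded fires is the part worth isolating; a simplified, runnable sketch (wrap_reply and the SimpleNamespace event context are stand-ins for the real plugin-event objects, not the project's API):

    from types import SimpleNamespace

    def wrap_reply(event_ctx, reply_text):
        # Mirrors the wrapper's branching after the event is emitted.
        if event_ctx.is_prevented_default():
            return None                   # stage yields INTERRUPT instead of a reply
        if event_ctx.event.reply is not None:
            return event_ctx.event.reply  # a plugin replaced the reply outright
        return [reply_text]               # default: wrap the model's own text

    ctx = SimpleNamespace(is_prevented_default=lambda: False,
                          event=SimpleNamespace(reply=None))
    print(wrap_reply(ctx, 'hello'))       # ['hello']
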
diff --git a/pkg/provider/requester/api.py b/pkg/provider/requester/api.py
index 88d500e..88ba78c 100644
--- a/pkg/provider/requester/api.py
+++ b/pkg/provider/requester/api.py
@@ -23,7 +23,6 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
     async def request(
         self,
         query: core_entities.Query,
-        conversation: core_entities.Conversation,
     ) -> typing.AsyncGenerator[llm_entities.Message, None]:
         """Request
         """
diff --git a/pkg/provider/requester/apis/chatcmpl.py b/pkg/provider/requester/apis/chatcmpl.py
index 52b895c..5101ad3 100644
--- a/pkg/provider/requester/apis/chatcmpl.py
+++ b/pkg/provider/requester/apis/chatcmpl.py
@@ -7,9 +7,10 @@ import json
 import openai
 import openai.types.chat.chat_completion as chat_completion
 
-from .. import api
+from .. import api, entities
 from ....core import entities as core_entities
 from ... import entities as llm_entities
+from ...tools import entities as tools_entities
 
 
 class OpenAIChatCompletion(api.LLMAPIRequester):
@@ -42,15 +43,16 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
     async def _closure(
         self,
         req_messages: list[dict],
-        conversation: core_entities.Conversation,
+        use_model: entities.LLMModelInfo,
+        use_funcs: list[tools_entities.LLMFunction] = None,
     ) -> llm_entities.Message:
-        self.client.api_key = conversation.use_model.token_mgr.get_token()
+        self.client.api_key = use_model.token_mgr.get_token()
 
         args = self.ap.cfg_mgr.data["completion_api_params"].copy()
-        args["model"] = conversation.use_model.name if conversation.use_model.model_name is None else conversation.use_model.model_name
+        args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
 
-        if conversation.use_model.tool_call_supported:
-            tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation)
+        if use_model.tool_call_supported:
+            tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
 
             if tools:
                 args["tools"] = tools
@@ -68,19 +70,19 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
 
         return message
 
     async def request(
-        self, query: core_entities.Query, conversation: core_entities.Conversation
+        self, query: core_entities.Query
     ) -> typing.AsyncGenerator[llm_entities.Message, None]:
         """Request"""
         pending_tool_calls = []
 
-        req_messages = [
-            m.dict(exclude_none=True) for m in conversation.prompt.messages
-        ] + [m.dict(exclude_none=True) for m in conversation.messages]
+        req_messages = [  # req_messages is only used within this class; external sync goes through query.messages
+            m.dict(exclude_none=True) for m in query.prompt.messages
+        ] + [m.dict(exclude_none=True) for m in query.messages]
 
         # req_messages.append({"role": "user", "content": str(query.message_chain)})
 
-        msg = await self._closure(req_messages, conversation)
+        msg = await self._closure(req_messages, query.use_model, query.use_funcs)
 
         yield msg
 
@@ -107,7 +109,7 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
             req_messages.append(msg.dict(exclude_none=True))
 
         # All calls handled; continue the request
-        msg = await self._closure(req_messages, conversation)
+        msg = await self._closure(req_messages, query.use_model, query.use_funcs)
 
         yield msg
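
With conversation gone from the requester signature, everything is derived from the query: req_messages is the prompt plus the history, and each round of tool calls re-enters _closure with the grown list. One detail that is easy to misread is the model-name fallback in _closure; name appears to be the registry key and model_name the wire-level override, though that reading is an inference from this hunk, not stated in the patch. A runnable restatement with made-up model objects:

    from types import SimpleNamespace

    def resolve_model_name(use_model):
        # identical to the expression in _closure above
        return use_model.name if use_model.model_name is None else use_model.model_name

    aliased = SimpleNamespace(name='gpt-4', model_name='gpt-4-0125-preview')
    direct = SimpleNamespace(name='gpt-3.5-turbo', model_name=None)
    print(resolve_model_name(aliased))  # gpt-4-0125-preview
    print(resolve_model_name(direct))   # gpt-3.5-turbo
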
diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py
index e0a1ec2..72c892b 100644
--- a/pkg/provider/tools/toolmgr.py
+++ b/pkg/provider/tools/toolmgr.py
@@ -38,12 +38,12 @@ class ToolManager:
 
         return all_functions
 
-    async def generate_tools_for_openai(self, conversation: core_entities.Conversation) -> str:
+    async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> str:
         """Generate the function list
         """
         tools = []
 
-        for function in conversation.use_funcs:
+        for function in use_funcs:
             if function.enable:
                 function_schema = {
                     "type": "function",
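
generate_tools_for_openai now takes the function list directly, which decouples ToolManager from Conversation and lets the requester pass query.use_funcs straight through. The hunk only shows the top of the schema it builds; the remainder presumably follows the OpenAI chat-completions tools format, roughly as in this standalone sketch (the get_weather function and the dict-based LLMFunction stand-in are made up for illustration):

    def generate_tools(use_funcs):
        tools = []
        for function in use_funcs:
            if function['enable']:
                tools.append({
                    "type": "function",
                    "function": {
                        "name": function['name'],
                        "description": function['description'],
                        "parameters": function['parameters'],  # JSON Schema object
                    },
                })
        return tools

    get_weather = {
        "enable": True,
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }
    print(generate_tools([get_weather]))
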