mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 03:32:33 +08:00
feat: runner 层抽象 (#839)
This commit is contained in:
parent
48cc3656bd
commit
8cad4089a7
|
@ -9,6 +9,7 @@ from ..provider.session import sessionmgr as llm_session_mgr
|
||||||
from ..provider.modelmgr import modelmgr as llm_model_mgr
|
from ..provider.modelmgr import modelmgr as llm_model_mgr
|
||||||
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
|
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
|
||||||
from ..provider.tools import toolmgr as llm_tool_mgr
|
from ..provider.tools import toolmgr as llm_tool_mgr
|
||||||
|
from ..provider import runnermgr
|
||||||
from ..config import manager as config_mgr
|
from ..config import manager as config_mgr
|
||||||
from ..audit.center import v2 as center_mgr
|
from ..audit.center import v2 as center_mgr
|
||||||
from ..command import cmdmgr
|
from ..command import cmdmgr
|
||||||
|
@ -33,6 +34,8 @@ class Application:
|
||||||
|
|
||||||
tool_mgr: llm_tool_mgr.ToolManager = None
|
tool_mgr: llm_tool_mgr.ToolManager = None
|
||||||
|
|
||||||
|
runner_mgr: runnermgr.RunnerManager = None
|
||||||
|
|
||||||
# ======= 配置管理器 =======
|
# ======= 配置管理器 =======
|
||||||
|
|
||||||
command_cfg: config_mgr.ConfigManager = None
|
command_cfg: config_mgr.ConfigManager = None
|
||||||
|
|
19
pkg/core/migrations/m012_runner_config.py
Normal file
19
pkg/core/migrations/m012_runner_config.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
|
@migration.migration_class("runner-config", 12)
|
||||||
|
class RunnerConfigMigration(migration.Migration):
|
||||||
|
"""迁移"""
|
||||||
|
|
||||||
|
async def need_migrate(self) -> bool:
|
||||||
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
|
return 'runner' not in self.ap.provider_cfg.data
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""执行迁移"""
|
||||||
|
|
||||||
|
self.ap.provider_cfg.data['runner'] = 'local-agent'
|
||||||
|
|
||||||
|
await self.ap.provider_cfg.dump_config()
|
|
@ -13,6 +13,7 @@ from ...provider.session import sessionmgr as llm_session_mgr
|
||||||
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||||
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
|
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
|
||||||
from ...provider.tools import toolmgr as llm_tool_mgr
|
from ...provider.tools import toolmgr as llm_tool_mgr
|
||||||
|
from ...provider import runnermgr
|
||||||
from ...platform import manager as im_mgr
|
from ...platform import manager as im_mgr
|
||||||
|
|
||||||
@stage.stage_class("BuildAppStage")
|
@stage.stage_class("BuildAppStage")
|
||||||
|
@ -81,6 +82,11 @@ class BuildAppStage(stage.BootingStage):
|
||||||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
|
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
|
||||||
await llm_tool_mgr_inst.initialize()
|
await llm_tool_mgr_inst.initialize()
|
||||||
ap.tool_mgr = llm_tool_mgr_inst
|
ap.tool_mgr = llm_tool_mgr_inst
|
||||||
|
|
||||||
|
runner_mgr_inst = runnermgr.RunnerManager(ap)
|
||||||
|
await runner_mgr_inst.initialize()
|
||||||
|
ap.runner_mgr = runner_mgr_inst
|
||||||
|
|
||||||
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
|
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
|
||||||
await im_mgr_inst.initialize()
|
await im_mgr_inst.initialize()
|
||||||
ap.platform_mgr = im_mgr_inst
|
ap.platform_mgr = im_mgr_inst
|
||||||
|
|
|
@ -6,7 +6,7 @@ from .. import stage, app
|
||||||
from .. import migration
|
from .. import migration
|
||||||
from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
|
from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
|
||||||
from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
|
from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
|
||||||
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config
|
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config
|
||||||
|
|
||||||
|
|
||||||
@stage.stage_class("MigrationStage")
|
@stage.stage_class("MigrationStage")
|
||||||
|
|
|
@ -10,7 +10,7 @@ import mirai
|
||||||
from .. import handler
|
from .. import handler
|
||||||
from ... import entities
|
from ... import entities
|
||||||
from ....core import entities as core_entities
|
from ....core import entities as core_entities
|
||||||
from ....provider import entities as llm_entities
|
from ....provider import entities as llm_entities, runnermgr
|
||||||
from ....plugin import events
|
from ....plugin import events
|
||||||
|
|
||||||
|
|
||||||
|
@ -71,7 +71,9 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
async for result in self.runner(query):
|
runner = self.ap.runner_mgr.get_runner()
|
||||||
|
|
||||||
|
async for result in runner.run(query):
|
||||||
query.resp_messages.append(result)
|
query.resp_messages.append(result)
|
||||||
|
|
||||||
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
|
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
|
||||||
|
@ -108,64 +110,3 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||||
response_seconds=int(time.time() - start_time),
|
response_seconds=int(time.time() - start_time),
|
||||||
retry_times=-1,
|
retry_times=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def runner(
|
|
||||||
self,
|
|
||||||
query: core_entities.Query,
|
|
||||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
|
||||||
"""执行一个请求处理过程中的LLM接口请求、函数调用的循环
|
|
||||||
|
|
||||||
这是临时处理方案,后续可能改为使用LangChain或者自研的工作流处理器
|
|
||||||
"""
|
|
||||||
await query.use_model.requester.preprocess(query)
|
|
||||||
|
|
||||||
pending_tool_calls = []
|
|
||||||
|
|
||||||
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
|
|
||||||
|
|
||||||
# 首次请求
|
|
||||||
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
|
|
||||||
|
|
||||||
yield msg
|
|
||||||
|
|
||||||
pending_tool_calls = msg.tool_calls
|
|
||||||
|
|
||||||
req_messages.append(msg)
|
|
||||||
|
|
||||||
# 持续请求,只要还有待处理的工具调用就继续处理调用
|
|
||||||
while pending_tool_calls:
|
|
||||||
for tool_call in pending_tool_calls:
|
|
||||||
try:
|
|
||||||
func = tool_call.function
|
|
||||||
|
|
||||||
parameters = json.loads(func.arguments)
|
|
||||||
|
|
||||||
func_ret = await self.ap.tool_mgr.execute_func_call(
|
|
||||||
query, func.name, parameters
|
|
||||||
)
|
|
||||||
|
|
||||||
msg = llm_entities.Message(
|
|
||||||
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
|
|
||||||
)
|
|
||||||
|
|
||||||
yield msg
|
|
||||||
|
|
||||||
req_messages.append(msg)
|
|
||||||
except Exception as e:
|
|
||||||
# 工具调用出错,添加一个报错信息到 req_messages
|
|
||||||
err_msg = llm_entities.Message(
|
|
||||||
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
|
|
||||||
)
|
|
||||||
|
|
||||||
yield err_msg
|
|
||||||
|
|
||||||
req_messages.append(err_msg)
|
|
||||||
|
|
||||||
# 处理完所有调用,再次请求
|
|
||||||
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
|
|
||||||
|
|
||||||
yield msg
|
|
||||||
|
|
||||||
pending_tool_calls = msg.tool_calls
|
|
||||||
|
|
||||||
req_messages.append(msg)
|
|
||||||
|
|
40
pkg/provider/runner.py
Normal file
40
pkg/provider/runner.py
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import typing
|
||||||
|
|
||||||
|
from ..core import app, entities as core_entities
|
||||||
|
from . import entities as llm_entities
|
||||||
|
|
||||||
|
|
||||||
|
preregistered_runners: list[typing.Type[RequestRunner]] = []
|
||||||
|
|
||||||
|
def runner_class(name: str):
|
||||||
|
"""注册一个请求运行器
|
||||||
|
"""
|
||||||
|
def decorator(cls: typing.Type[RequestRunner]) -> typing.Type[RequestRunner]:
|
||||||
|
cls.name = name
|
||||||
|
preregistered_runners.append(cls)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class RequestRunner(abc.ABC):
|
||||||
|
"""请求运行器
|
||||||
|
"""
|
||||||
|
name: str = None
|
||||||
|
|
||||||
|
ap: app.Application
|
||||||
|
|
||||||
|
def __init__(self, ap: app.Application):
|
||||||
|
self.ap = ap
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
|
"""运行请求
|
||||||
|
"""
|
||||||
|
pass
|
27
pkg/provider/runnermgr.py
Normal file
27
pkg/provider/runnermgr.py
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from . import runner
|
||||||
|
from ..core import app
|
||||||
|
|
||||||
|
from .runners import localagent
|
||||||
|
|
||||||
|
|
||||||
|
class RunnerManager:
|
||||||
|
|
||||||
|
ap: app.Application
|
||||||
|
|
||||||
|
using_runner: runner.RequestRunner
|
||||||
|
|
||||||
|
def __init__(self, ap: app.Application):
|
||||||
|
self.ap = ap
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
|
||||||
|
for r in runner.preregistered_runners:
|
||||||
|
if r.name == self.ap.provider_cfg.data['runner']:
|
||||||
|
self.using_runner = r(self.ap)
|
||||||
|
await self.using_runner.initialize()
|
||||||
|
break
|
||||||
|
|
||||||
|
def get_runner(self) -> runner.RequestRunner:
|
||||||
|
return self.using_runner
|
0
pkg/provider/runners/__init__.py
Normal file
0
pkg/provider/runners/__init__.py
Normal file
70
pkg/provider/runners/localagent.py
Normal file
70
pkg/provider/runners/localagent.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import typing
|
||||||
|
|
||||||
|
from .. import runner
|
||||||
|
from ...core import app, entities as core_entities
|
||||||
|
from .. import entities as llm_entities
|
||||||
|
|
||||||
|
|
||||||
|
@runner.runner_class("local-agent")
|
||||||
|
class LocalAgentRunner(runner.RequestRunner):
|
||||||
|
"""本地Agent请求运行器
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
|
"""运行请求
|
||||||
|
"""
|
||||||
|
await query.use_model.requester.preprocess(query)
|
||||||
|
|
||||||
|
pending_tool_calls = []
|
||||||
|
|
||||||
|
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
|
||||||
|
|
||||||
|
# 首次请求
|
||||||
|
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
|
||||||
|
|
||||||
|
yield msg
|
||||||
|
|
||||||
|
pending_tool_calls = msg.tool_calls
|
||||||
|
|
||||||
|
req_messages.append(msg)
|
||||||
|
|
||||||
|
# 持续请求,只要还有待处理的工具调用就继续处理调用
|
||||||
|
while pending_tool_calls:
|
||||||
|
for tool_call in pending_tool_calls:
|
||||||
|
try:
|
||||||
|
func = tool_call.function
|
||||||
|
|
||||||
|
parameters = json.loads(func.arguments)
|
||||||
|
|
||||||
|
func_ret = await self.ap.tool_mgr.execute_func_call(
|
||||||
|
query, func.name, parameters
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = llm_entities.Message(
|
||||||
|
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
|
||||||
|
)
|
||||||
|
|
||||||
|
yield msg
|
||||||
|
|
||||||
|
req_messages.append(msg)
|
||||||
|
except Exception as e:
|
||||||
|
# 工具调用出错,添加一个报错信息到 req_messages
|
||||||
|
err_msg = llm_entities.Message(
|
||||||
|
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
|
||||||
|
)
|
||||||
|
|
||||||
|
yield err_msg
|
||||||
|
|
||||||
|
req_messages.append(err_msg)
|
||||||
|
|
||||||
|
# 处理完所有调用,再次请求
|
||||||
|
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
|
||||||
|
|
||||||
|
yield msg
|
||||||
|
|
||||||
|
pending_tool_calls = msg.tool_calls
|
||||||
|
|
||||||
|
req_messages.append(msg)
|
|
@ -48,5 +48,6 @@
|
||||||
"prompt-mode": "normal",
|
"prompt-mode": "normal",
|
||||||
"prompt": {
|
"prompt": {
|
||||||
"default": ""
|
"default": ""
|
||||||
}
|
},
|
||||||
|
"runner": "local-agent"
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user