Mirror of https://github.com/RockChinQ/QChatGPT.git, synced 2024-11-16 11:42:44 +08:00
Commit a7f830dd73
@@ -13,11 +13,16 @@ class CommandReturn(pydantic.BaseModel):
     """Command return value
     """

-    text: typing.Optional[str]
+    text: typing.Optional[str] = None
     """Text
     """

-    image: typing.Optional[mirai.Image]
+    image: typing.Optional[mirai.Image] = None
     """Deprecated"""

+    image_url: typing.Optional[str] = None
+    """Image URL
+    """
+
     error: typing.Optional[errors.CommandError] = None
     """Error
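Note: with text, image, and image_url all optional and defaulting to None, a command operator can now return an image by URL instead of a raw mirai.Image object. A minimal sketch of the intended use, assuming an operator body like the ones below (text and URL values are illustrative):

    # Hypothetical operator body: return text plus an image via the new field.
    yield entities.CommandReturn(
        text="生成完成",                              # optional text part
        image_url="https://example.com/result.png",  # optional image part; replaces the deprecated `image`
    )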
@@ -24,7 +24,7 @@ class DefaultOperator(operator.CommandOperator):
             content = ""
             for msg in prompt.messages:
-                content += f"  {msg.role}: {msg.content}"
+                content += f"  {msg.readable_str()}\n"

             reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
@@ -45,18 +45,18 @@ class DefaultSetOperator(operator.CommandOperator):
         context: entities.ExecuteContext
     ) -> typing.AsyncGenerator[entities.CommandReturn, None]:

-        if len(context.crt_params) == 0:
-            yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
-        else:
-            prompt_name = context.crt_params[0]
-
-            try:
-                prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
-                if prompt is None:
-                    yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
-                else:
-                    context.session.use_prompt_name = prompt.name
-                    yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
-            except Exception as e:
-                traceback.print_exc()
-                yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))
+        if len(context.crt_params) == 0:
+            yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
+        else:
+            prompt_name = context.crt_params[0]
+
+            try:
+                prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
+                if prompt is None:
+                    yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
+                else:
+                    context.session.use_prompt_name = prompt.name
+                    yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
+            except Exception as e:
+                traceback.print_exc()
+                yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))
@@ -30,7 +30,7 @@ class LastOperator(operator.CommandOperator):
                 context.session.using_conversation = context.session.conversations[index-1]
                 time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")

-                yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
+                yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}")
                 return
         else:
             yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
@@ -42,7 +42,7 @@ class ListOperator(operator.CommandOperator):
                 using_conv_index = index

             if index >= page * record_per_page and index < (page + 1) * record_per_page:
-                content += f"{index} {time_str}: {conv.messages[0].content if len(conv.messages) > 0 else '无内容'}\n"
+                content += f"{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else '无内容'}\n"

             index += 1

         if content == '':
@@ -51,6 +51,6 @@ class ListOperator(operator.CommandOperator):
         if context.session.using_conversation is None:
             content += "\n当前处于新会话"
         else:
-            content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content if len(context.session.using_conversation.messages) > 0 else '无内容'}"
+            content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else '无内容'}"

         yield entities.CommandReturn(text=f"第 {page + 1} 页 (时间倒序):\n{content}")
@@ -27,7 +27,7 @@ class JSONConfigFile(file_model.ConfigFile):
         else:
             raise ValueError("template_file_name or template_data must be provided")

-    async def load(self) -> dict:
+    async def load(self, completion: bool=True) -> dict:

         if not self.exists():
             await self.create()
@@ -39,9 +39,11 @@ class JSONConfigFile(file_model.ConfigFile):
         with open(self.config_file_name, "r", encoding="utf-8") as f:
             cfg = json.load(f)

-        for key in self.template_data:
-            if key not in cfg:
-                cfg[key] = self.template_data[key]
+        if completion:
+
+            for key in self.template_data:
+                if key not in cfg:
+                    cfg[key] = self.template_data[key]

         return cfg
@@ -25,7 +25,7 @@ class PythonModuleConfigFile(file_model.ConfigFile):
     async def create(self):
         shutil.copyfile(self.template_file_name, self.config_file_name)

-    async def load(self) -> dict:
+    async def load(self, completion: bool=True) -> dict:
         module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
         module = importlib.import_module(module_name)
@@ -43,18 +43,19 @@ class PythonModuleConfigFile(file_model.ConfigFile):
                 cfg[key] = getattr(module, key)

         # fill in missing keys from the template module file
-        module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
-        module = importlib.import_module(module_name)
+        if completion:
+            module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
+            module = importlib.import_module(module_name)

-        for key in dir(module):
-            if key.startswith('__'):
-                continue
+            for key in dir(module):
+                if key.startswith('__'):
+                    continue

-            if not isinstance(getattr(module, key), allowed_types):
-                continue
+                if not isinstance(getattr(module, key), allowed_types):
+                    continue

-            if key not in cfg:
-                cfg[key] = getattr(module, key)
+                if key not in cfg:
+                    cfg[key] = getattr(module, key)

         return cfg
@@ -20,8 +20,8 @@ class ConfigManager:
         self.file = cfg_file
         self.data = {}

-    async def load_config(self):
-        self.data = await self.file.load()
+    async def load_config(self, completion: bool=True):
+        self.data = await self.file.load(completion=completion)

     async def dump_config(self):
         await self.file.save(self.data)
@@ -30,7 +30,7 @@ class ConfigManager:
         self.file.save_sync(self.data)


-async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
+async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager:
    """Load a Python module config file"""
    cfg_inst = pymodule.PythonModuleConfigFile(
        config_name,
@@ -38,12 +38,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
     )

     cfg_mgr = ConfigManager(cfg_inst)
-    await cfg_mgr.load_config()
+    await cfg_mgr.load_config(completion=completion)

     return cfg_mgr


-async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager:
+async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
    """Load a JSON config file"""
    cfg_inst = json_file.JSONConfigFile(
        config_name,
@@ -52,6 +52,6 @@ async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager:
     )

     cfg_mgr = ConfigManager(cfg_inst)
-    await cfg_mgr.load_config()
+    await cfg_mgr.load_config(completion=completion)

     return cfg_mgr
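Note: the new completion flag threads from ConfigFile.load() through ConfigManager.load_config() into both loader helpers, so callers can opt out of template completion. A minimal sketch of the call pattern (paths are the ones used elsewhere in this commit):

    # completion=True (default): keys missing from the user file are filled from the template
    cfg_mgr = await load_json_config("data/config/provider.json", "templates/provider.json")

    # completion=False: load the file as-is; missing keys stay missing
    raw_mgr = await load_json_config("data/config/provider.json", "templates/provider.json", completion=False)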
pkg/config/migrations/m006_vision_config.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+from .. import migration
+
+
+@migration.migration_class("vision-config", 6)
+class VisionConfigMigration(migration.Migration):
+    """Migration"""
+
+    async def need_migrate(self) -> bool:
+        """Determine whether this migration needs to run in the current environment"""
+        return "enable-vision" not in self.ap.provider_cfg.data
+
+    async def run(self):
+        """Run the migration"""
+        if "enable-vision" not in self.ap.provider_cfg.data:
+            self.ap.provider_cfg.data["enable-vision"] = False
+
+        await self.ap.provider_cfg.dump_config()
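Note: the decorator registers this migration under the name "vision-config" with sequence number 6; a driver is then expected to call need_migrate() and, if it returns True, run(). A rough sketch of such a driver loop, assuming a registry list and a sequence attribute in the style of the other preregistration patterns in this codebase (both names are assumptions):

    # Hypothetical driver: run every registered migration that reports it is needed.
    for migration_cls in sorted(preregistered_migrations, key=lambda m: m.number):
        inst = migration_cls(ap)
        if await inst.need_migrate():   # e.g. "enable-vision" missing from provider.json
            await inst.run()            # writes the default value and persists it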
@@ -22,7 +22,7 @@ class ConfigFile(metaclass=abc.ABCMeta):
         pass

     @abc.abstractmethod
-    async def load(self) -> dict:
+    async def load(self, completion: bool=True) -> dict:
         pass

     @abc.abstractmethod
@@ -14,6 +14,7 @@ required_deps = {
     "yaml": "pyyaml",
     "aiohttp": "aiohttp",
     "psutil": "psutil",
+    "async_lru": "async-lru",
 }
@@ -70,7 +70,7 @@ class Query(pydantic.BaseModel):
     resp_messages: typing.Optional[list[llm_entities.Message]] = []
     """List of reply message objects produced by the Process stage"""

-    resp_message_chain: typing.Optional[mirai.MessageChain] = None
+    resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None
     """Reply message chains, wrapped from resp_messages"""

     class Config:
@@ -15,7 +15,6 @@ from ...provider.sysprompt import sysprompt as llm_prompt_mgr
 from ...provider.tools import toolmgr as llm_tool_mgr
-from ...platform import manager as im_mgr


 @stage.stage_class("BuildAppStage")
 class BuildAppStage(stage.BootingStage):
     """Application build stage
@@ -83,7 +82,6 @@ class BuildAppStage(stage.BootingStage):
         llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
         await llm_tool_mgr_inst.initialize()
         ap.tool_mgr = llm_tool_mgr_inst
-
         im_mgr_inst = im_mgr.PlatformManager(ap=ap)
         await im_mgr_inst.initialize()
         ap.platform_mgr = im_mgr_inst
@@ -92,5 +90,6 @@ class BuildAppStage(stage.BootingStage):
         await stage_mgr.initialize()
         ap.stage_mgr = stage_mgr

+
         ctrl = controller.Controller(ap)
         ap.ctrl = ctrl
@@ -12,11 +12,11 @@ class LoadConfigStage(stage.BootingStage):
     async def run(self, ap: app.Application):
         """Startup
         """
-        ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json")
-        ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json")
-        ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json")
-        ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json")
-        ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json")
+        ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json", completion=False)
+        ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json", completion=False)
+        ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json", completion=False)
+        ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json", completion=False)
+        ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json", completion=False)

         ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json")
         await ap.plugin_setting_meta.dump_config()
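Note: switching the five core config files to completion=False means missing keys are no longer silently filled from the templates at load time; a migration such as m006_vision_config can therefore still detect the missing key and persist it explicitly. A sketch of the resulting startup behavior (keys taken from the diffs above):

    ap.provider_cfg = await config.load_json_config(
        "data/config/provider.json", "templates/provider.json", completion=False
    )
    # "enable-vision" may be absent here; m006_vision_config.need_migrate() then
    # returns True and run() writes "enable-vision": false via dump_config().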
@@ -4,7 +4,7 @@ import importlib

 from .. import stage, app
 from ...config import migration
-from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
+from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion, m006_vision_config
 from ...config.migrations import m005_deepseek_cfg_completion
@@ -8,7 +8,10 @@ from ...config import manager as cfg_mgr

 @stage.stage_class('BanSessionCheckStage')
 class BanSessionCheckStage(stage.PipelineStage):
-    """Access control stage"""
+    """Access control stage
+
+    Only checks whether the group or user id in the query is on the access control list.
+    """

     async def initialize(self):
         pass
@@ -9,12 +9,24 @@ from ...core import entities as core_entities
 from ...config import manager as cfg_mgr
 from . import filter as filter_model, entities as filter_entities
 from .filters import cntignore, banwords, baiduexamine
+from ...provider import entities as llm_entities


 @stage.stage_class('PostContentFilterStage')
 @stage.stage_class('PreContentFilterStage')
 class ContentFilterStage(stage.PipelineStage):
-    """Content filter stage"""
+    """Content filter stage
+
+    Pre stage:
+        Checks whether the message conforms to the rules; intercepts it if not.
+        Rewrites:
+            - message_chain
+
+    Post stage:
+        Checks whether the AI reply conforms to the rules, possibly rewriting it; intercepts it if not.
+        Rewrites:
+            - query.resp_messages
+    """

     filter_chain: list[filter_model.ContentFilter]
@@ -130,6 +142,21 @@ class ContentFilterStage(stage.PipelineStage):
         """Process
         """
+        if stage_inst_name == 'PreContentFilterStage':
+
+            contain_non_text = False
+
+            for me in query.message_chain:
+                if not isinstance(me, mirai.Plain):
+                    contain_non_text = True
+                    break
+
+            if contain_non_text:
+                self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。")
+                return entities.StageProcessResult(
+                    result_type=entities.ResultType.CONTINUE,
+                    new_query=query
+                )
+
         return await self._pre_process(
             str(query.message_chain).strip(),
             query
@@ -4,6 +4,8 @@ import enum

 import pydantic

+from ...provider import entities as llm_entities
+

 class ResultLevel(enum.Enum):
     """Result level"""
@@ -38,7 +40,7 @@ class FilterResult(pydantic.BaseModel):
     """

     replacement: str
-    """Message after replacement
+    """Text message after replacement

     A content filter may apply some masking and return the masked message.
     If the content was not modified, the original message must still be returned.
@@ -5,6 +5,7 @@ import typing

 from ...core import app
 from . import entities
+from ...provider import entities as llm_entities


 preregistered_filters: list[typing.Type[ContentFilter]] = []
@@ -63,7 +64,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
         pass

     @abc.abstractmethod
-    async def process(self, message: str) -> entities.FilterResult:
+    async def process(self, message: str=None, image_url=None) -> entities.FilterResult:
         """Process a message

         Runs as a pre or post stage, depending on the value of enable_stages.
@@ -71,6 +72,7 @@ class ContentFilter(metaclass=abc.ABCMeta):

         Args:
             message (str): the content to check
+            image_url (str): URL of the image to check

         Returns:
             entities.FilterResult: the filter result; see the entities.FilterResult class docs for details
@@ -8,7 +8,7 @@ from ....config import manager as cfg_mgr

 @filter_model.filter_class("ban-word-filter")
 class BanWordFilter(filter_model.ContentFilter):
-    """Mute based on content"""
+    """Filter based on content"""

     async def initialize(self):
         pass
@@ -16,6 +16,9 @@ from ...config import manager as cfg_mgr
 @stage.stage_class("LongTextProcessStage")
 class LongTextProcessStage(stage.PipelineStage):
     """Long message handling stage
+
+    Rewrites:
+        - resp_message_chain
     """

     strategy_impl: strategy.LongTextStrategy
@@ -59,15 +62,15 @@ class LongTextProcessStage(stage.PipelineStage):
         # check whether the chain contains non-Plain components
         contains_non_plain = False

-        for msg in query.resp_message_chain:
+        for msg in query.resp_message_chain[-1]:
             if not isinstance(msg, Plain):
                 contains_non_plain = True
                 break

         if contains_non_plain:
             self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
-        elif len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
-            query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query))
+        elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
+            query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))

         return entities.StageProcessResult(
             result_type=entities.ResultType.CONTINUE,
@@ -43,7 +43,7 @@ class QueryPool:
             message_event=message_event,
             message_chain=message_chain,
             resp_messages=[],
-            resp_message_chain=None,
+            resp_message_chain=[],
             adapter=adapter
         )
         self.queries.append(query)
@@ -1,5 +1,7 @@
 from __future__ import annotations

+import mirai
+
 from .. import stage, entities, stagemgr
 from ...core import entities as core_entities
 from ...provider import entities as llm_entities
@@ -9,6 +11,16 @@ from ...plugin import events
 @stage.stage_class("PreProcessor")
 class PreProcessor(stage.PipelineStage):
     """Request pre-processing stage
+
+    Checks out the session, prompt, context, model, and content functions.
+
+    Rewrites:
+        - session
+        - prompt
+        - messages
+        - user_message
+        - use_model
+        - use_funcs
     """

     async def process(
@@ -27,21 +39,42 @@ class PreProcessor(stage.PipelineStage):
         query.prompt = conversation.prompt.copy()
         query.messages = conversation.messages.copy()

-        query.user_message = llm_entities.Message(
-            role='user',
-            content=str(query.message_chain).strip()
-        )
-
         query.use_model = conversation.use_model

-        query.use_funcs = conversation.use_funcs
+        query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None
+
+        # check whether vision is enabled; if not, drop all images
+        if not self.ap.provider_cfg.data['enable-vision'] or not query.use_model.vision_supported:
+            for msg in query.messages:
+                if isinstance(msg.content, list):
+                    for me in msg.content:
+                        if me.type == 'image_url':
+                            msg.content.remove(me)
+
+        content_list = []
+
+        for me in query.message_chain:
+            if isinstance(me, mirai.Plain):
+                content_list.append(
+                    llm_entities.ContentElement.from_text(me.text)
+                )
+            elif isinstance(me, mirai.Image):
+                if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported:
+                    if me.url is not None:
+                        content_list.append(
+                            llm_entities.ContentElement.from_image_url(str(me.url))
+                        )
+
+        query.user_message = llm_entities.Message(  # TODO adapt multimodal input
+            role='user',
+            content=content_list
+        )
        # =========== trigger the PromptPreProcessing event
-        session = query.session

        event_ctx = await self.ap.plugin_mgr.emit_event(
            event=events.PromptPreProcessing(
-                session_name=f'{session.launcher_type.value}_{session.launcher_id}',
+                session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
                default_prompt=query.prompt.messages,
                prompt=query.messages,
                query=query
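Note: query.user_message now carries a list of ContentElement objects instead of a flat string. For an incoming message with text and one image, on a vision-capable model with enable-vision on, the result looks roughly like this (values illustrative):

    query.user_message = llm_entities.Message(
        role='user',
        content=[
            llm_entities.ContentElement.from_text("这是什么?"),
            # image elements are only appended when enable-vision is true
            # and query.use_model.vision_supported is true
            llm_entities.ContentElement.from_image_url("https://example.com/photo.jpg"),
        ],
    )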
@@ -3,6 +3,7 @@ from __future__ import annotations
 import typing
 import time
 import traceback
+import json

 import mirai
@@ -70,17 +71,13 @@ class ChatMessageHandler(handler.MessageHandler):
                 mirai.Plain(event_ctx.event.alter)
             ])

-        query.messages.append(
-            query.user_message
-        )
-
         text_length = 0

         start_time = time.time()

         try:
-
-            async for result in query.use_model.requester.request(query):
+            async for result in self.runner(query):
                 query.resp_messages.append(result)

                 self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
@@ -92,6 +89,9 @@ class ChatMessageHandler(handler.MessageHandler):
                     result_type=entities.ResultType.CONTINUE,
                     new_query=query
                 )
+
+            query.session.using_conversation.messages.append(query.user_message)
+            query.session.using_conversation.messages.extend(query.resp_messages)
         except Exception as e:

             self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}')
@@ -104,8 +104,6 @@ class ChatMessageHandler(handler.MessageHandler):
                 debug_notice=traceback.format_exc()
             )
         finally:
-            query.session.using_conversation.messages.append(query.user_message)
-            query.session.using_conversation.messages.extend(query.resp_messages)

             await self.ap.ctr_mgr.usage.post_query_record(
                 session_type=query.session.launcher_type.value,
@@ -115,4 +113,65 @@ class ChatMessageHandler(handler.MessageHandler):
                     model_name=query.use_model.name,
                     response_seconds=int(time.time() - start_time),
                     retry_times=-1,
                 )
             )
+
+    async def runner(
+        self,
+        query: core_entities.Query,
+    ) -> typing.AsyncGenerator[llm_entities.Message, None]:
+        """Run the loop of LLM API requests and function calls for one query
+
+        This is a temporary solution; it may later be replaced by LangChain or a home-grown workflow processor
+        """
+        await query.use_model.requester.preprocess(query)
+
+        pending_tool_calls = []
+
+        req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
+
+        # first request
+        msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
+
+        yield msg
+
+        pending_tool_calls = msg.tool_calls
+
+        req_messages.append(msg)
+
+        # keep requesting as long as there are pending tool calls
+        while pending_tool_calls:
+            for tool_call in pending_tool_calls:
+                try:
+                    func = tool_call.function
+
+                    parameters = json.loads(func.arguments)
+
+                    func_ret = await self.ap.tool_mgr.execute_func_call(
+                        query, func.name, parameters
+                    )
+
+                    msg = llm_entities.Message(
+                        role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
+                    )
+
+                    yield msg
+
+                    req_messages.append(msg)
+                except Exception as e:
+                    # the tool call failed; append an error message to req_messages
+                    err_msg = llm_entities.Message(
+                        role="tool", content=f"err: {e}", tool_call_id=tool_call.id
+                    )
+
+                    yield err_msg
+
+                    req_messages.append(err_msg)
+
+            # all calls handled, request again
+            msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
+
+            yield msg
+
+            pending_tool_calls = msg.tool_calls
+
+            req_messages.append(msg)
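Note: runner yields every intermediate message: the first assistant reply, one role="tool" message per executed call, then each follow-up assistant reply, terminating once a reply arrives with empty tool_calls. The caller consumes it exactly as process() does above:

    # Each yielded Message lands in query.resp_messages;
    # a typical sequence is assistant -> tool -> assistant.
    async for result in self.runner(query):
        query.resp_messages.append(result)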
@@ -80,9 +80,6 @@ class CommandHandler(handler.MessageHandler):
                     session=session
                 ):
                     if ret.error is not None:
-                        # query.resp_message_chain = mirai.MessageChain([
-                        #     mirai.Plain(str(ret.error))
-                        # ])
                         query.resp_messages.append(
                             llm_entities.Message(
                                 role='command',
@@ -96,18 +93,28 @@ class CommandHandler(handler.MessageHandler):
                             result_type=entities.ResultType.CONTINUE,
                             new_query=query
                         )
-                    elif ret.text is not None:
-                        # query.resp_message_chain = mirai.MessageChain([
-                        #     mirai.Plain(ret.text)
-                        # ])
+                    elif ret.text is not None or ret.image_url is not None:
+
+                        content: list[llm_entities.ContentElement] = []
+
+                        if ret.text is not None:
+                            content.append(
+                                llm_entities.ContentElement.from_text(ret.text)
+                            )
+
+                        if ret.image_url is not None:
+                            content.append(
+                                llm_entities.ContentElement.from_image_url(ret.image_url)
+                            )
+
                         query.resp_messages.append(
                             llm_entities.Message(
                                 role='command',
-                                content=ret.text,
+                                content=content,
                             )
                         )

-                        self.ap.logger.info(f'命令返回: {self.cut_str(ret.text)}')
+                        self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')

                         yield entities.StageProcessResult(
                             result_type=entities.ResultType.CONTINUE,
@@ -11,7 +11,13 @@ from ...config import manager as cfg_mgr

 @stage.stage_class("MessageProcessor")
 class Processor(stage.PipelineStage):
-    """Actual request processing stage"""
+    """Actual request processing stage
+
+    Handles the message via the command handler and the chat handler.
+
+    Rewrites:
+        - resp_messages
+    """

     cmd_handler: handler.MessageHandler
@@ -11,7 +11,10 @@ from ...core import entities as core_entities
 @stage.stage_class("RequireRateLimitOccupancy")
 @stage.stage_class("ReleaseRateLimitOccupancy")
 class RateLimit(stage.PipelineStage):
-    """Rate limiter control stage"""
+    """Rate limiter control stage
+
+    Does not rewrite the query; only checks whether rate limiting is needed.
+    """

     algo: algo.ReteLimitAlgo
@@ -31,7 +31,7 @@ class SendResponseBackStage(stage.PipelineStage):

         await self.ap.platform_mgr.send(
             query.message_event,
-            query.resp_message_chain,
+            query.resp_message_chain[-1],
             adapter=query.adapter
         )
@@ -14,9 +14,12 @@ from ...config import manager as cfg_mgr
 @stage.stage_class("GroupRespondRuleCheckStage")
 class GroupRespondRuleCheckStage(stage.PipelineStage):
     """Group respond rule checker
+
+    Only checks whether group messages conform to the rules.
     """

     rule_matchers: list[rule.GroupRespondRule]
     """Checker instances"""

     async def initialize(self):
         """Initialize the checkers
@@ -31,7 +34,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):

     async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:

-        if query.launcher_type.value != 'group':
+        if query.launcher_type.value != 'group':  # only handle group messages
             return entities.StageProcessResult(
                 result_type=entities.ResultType.CONTINUE,
                 new_query=query
@@ -17,17 +17,17 @@ from .ratelimit import ratelimit

 # Order of the request processing stages
 stage_order = [
-    "GroupRespondRuleCheckStage",
-    "BanSessionCheckStage",
-    "PreContentFilterStage",
-    "PreProcessor",
-    "RequireRateLimitOccupancy",
-    "MessageProcessor",
-    "ReleaseRateLimitOccupancy",
-    "PostContentFilterStage",
-    "ResponseWrapper",
-    "LongTextProcessStage",
-    "SendResponseBackStage",
+    "GroupRespondRuleCheckStage",  # group respond rule check
+    "BanSessionCheckStage",  # banned session check
+    "PreContentFilterStage",  # content filter (pre)
+    "PreProcessor",  # pre-processor
+    "RequireRateLimitOccupancy",  # acquire rate-limit occupancy
+    "MessageProcessor",  # processor
+    "ReleaseRateLimitOccupancy",  # release rate-limit occupancy
+    "PostContentFilterStage",  # content filter (post)
+    "ResponseWrapper",  # response wrapper
+    "LongTextProcessStage",  # long text processing
+    "SendResponseBackStage",  # send response
 ]
@@ -14,6 +14,13 @@ from ...plugin import events

 @stage.stage_class("ResponseWrapper")
 class ResponseWrapper(stage.PipelineStage):
     """Reply wrapping stage
+
+    Wraps reply messages into a human-readable form.
+
+    Rewrites:
+        - resp_message_chain
     """

     async def initialize(self):
         pass
@@ -27,17 +34,19 @@ class ResponseWrapper(stage.PipelineStage):
         """

         if query.resp_messages[-1].role == 'command':
-            query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content)
+            # query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
+            query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))

             yield entities.StageProcessResult(
                 result_type=entities.ResultType.CONTINUE,
                 new_query=query
             )
         elif query.resp_messages[-1].role == 'plugin':
-            if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
-                query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
-            else:
-                query.resp_message_chain = query.resp_messages[-1].content
+            # if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
+            #     query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
+            # else:
+            #     query.resp_message_chain.append(query.resp_messages[-1].content)
+            query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())

             yield entities.StageProcessResult(
                 result_type=entities.ResultType.CONTINUE,
@@ -52,7 +61,7 @@ class ResponseWrapper(stage.PipelineStage):
             reply_text = ''

             if result.content is not None:  # has content
-                reply_text = result.content
+                reply_text = str(result.get_content_mirai_message_chain())

             # ============= trigger plugin event ===============
             event_ctx = await self.ap.plugin_mgr.emit_event(
@@ -76,11 +85,11 @@ class ResponseWrapper(stage.PipelineStage):
             else:
                 if event_ctx.event.reply is not None:

-                    query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
+                    query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))

                 else:

-                    query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
+                    query.resp_message_chain.append(result.get_content_mirai_message_chain())

             yield entities.StageProcessResult(
                 result_type=entities.ResultType.CONTINUE,
@@ -93,7 +102,7 @@ class ResponseWrapper(stage.PipelineStage):

             reply_text = f'调用函数 {".".join(function_names)}...'

-            query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
+            query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))

             if self.ap.platform_cfg.data['track-function-calls']:
@@ -119,13 +128,13 @@ class ResponseWrapper(stage.PipelineStage):
                 else:
                     if event_ctx.event.reply is not None:

-                        query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
+                        query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))

                     else:

-                        query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
+                        query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))

                 yield entities.StageProcessResult(
                     result_type=entities.ResultType.CONTINUE,
                     new_query=query
-                )
+                )
@@ -21,6 +21,39 @@ class ToolCall(pydantic.BaseModel):
     function: FunctionCall


+class ImageURLContentObject(pydantic.BaseModel):
+    url: str
+
+    def __str__(self):
+        return self.url[:128] + ('...' if len(self.url) > 128 else '')
+
+
+class ContentElement(pydantic.BaseModel):
+
+    type: str
+    """Content type"""
+
+    text: typing.Optional[str] = None
+
+    image_url: typing.Optional[ImageURLContentObject] = None
+
+    def __str__(self):
+        if self.type == 'text':
+            return self.text
+        elif self.type == 'image_url':
+            return f'[图片]({self.image_url})'
+        else:
+            return '未知内容'
+
+    @classmethod
+    def from_text(cls, text: str):
+        return cls(type='text', text=text)
+
+    @classmethod
+    def from_image_url(cls, image_url: str):
+        return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))
+
+
 class Message(pydantic.BaseModel):
     """Message"""
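Note: ContentElement mirrors the OpenAI-style content-part shape: exactly one of text or image_url is set, discriminated by type. A small sketch of the two constructors and their string forms:

    text_part = ContentElement.from_text("hello")
    image_part = ContentElement.from_image_url("https://example.com/cat.png")

    str(text_part)   # 'hello'
    str(image_part)  # '[图片](https://example.com/cat.png)' (URLs longer than 128 chars are truncated)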
@@ -30,12 +63,9 @@ class Message(pydantic.BaseModel):
     name: typing.Optional[str] = None
     """Name; only set for function-call returns"""

-    content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None
+    content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None
     """Content"""

-    function_call: typing.Optional[FunctionCall] = None
-    """Function call; no longer supported, use tool_calls"""
-
     tool_calls: typing.Optional[list[ToolCall]] = None
     """Tool calls"""
@@ -43,10 +73,38 @@ class Message(pydantic.BaseModel):

     def readable_str(self) -> str:
         if self.content is not None:
-            return str(self.content)
-        elif self.function_call is not None:
-            return f'{self.function_call.name}({self.function_call.arguments})'
+            return str(self.role) + ": " + str(self.get_content_mirai_message_chain())
         elif self.tool_calls is not None:
             return f'调用工具: {self.tool_calls[0].id}'
         else:
             return '未知消息'
+
+    def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None:
+        """Convert the content into a mirai MessageChain object
+
+        Args:
+            prefix_text (str): prefix text for the first text component
+        """
+
+        if self.content is None:
+            return None
+        elif isinstance(self.content, str):
+            return mirai.MessageChain([mirai.Plain(prefix_text+self.content)])
+        elif isinstance(self.content, list):
+            mc = []
+            for ce in self.content:
+                if ce.type == 'text':
+                    mc.append(mirai.Plain(ce.text))
+                elif ce.type == 'image_url':
+                    mc.append(mirai.Image(url=ce.image_url.url))
+
+            # find the first text component
+            if prefix_text:
+                for i, c in enumerate(mc):
+                    if isinstance(c, mirai.Plain):
+                        mc[i] = mirai.Plain(prefix_text+c.text)
+                        break
+                else:
+                    mc.insert(0, mirai.Plain(prefix_text))
+
+            return mirai.MessageChain(mc)
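Note: get_content_mirai_message_chain handles both legacy string content and the new element list, splicing prefix_text into the first Plain component or prepending one if none exists. A usage sketch:

    msg = Message(role='command', content=[ContentElement.from_text('done')])
    chain = msg.get_content_mirai_message_chain(prefix_text='[bot] ')
    # -> MessageChain([Plain('[bot] done')]); image_url elements become mirai.Image components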
@@ -6,6 +6,8 @@ import typing
 from ...core import app
 from ...core import entities as core_entities
 from .. import entities as llm_entities
+from . import entities as modelmgr_entities
+from ..tools import entities as tools_entities


 preregistered_requesters: list[typing.Type[LLMAPIRequester]] = []
@@ -33,20 +35,31 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
     async def initialize(self):
         pass

-    @abc.abstractmethod
-    async def request(
+    async def preprocess(
         self,
         query: core_entities.Query,
-    ) -> typing.AsyncGenerator[llm_entities.Message, None]:
-        """Request the API
+    ):
+        """Preprocess
+
+        Handle API-specific compatibility issues with the Query object here.
+        """
+        pass

-        The conversation context can be read from the query object.
-        Message objects may be yielded multiple times.
+    @abc.abstractmethod
+    async def call(
+        self,
+        model: modelmgr_entities.LLMModelInfo,
+        messages: typing.List[llm_entities.Message],
+        funcs: typing.List[tools_entities.LLMFunction] = None,
+    ) -> llm_entities.Message:
+        """Call the API

         Args:
-            query (core_entities.Query): the context object for this request
+            model (modelmgr_entities.LLMModelInfo): the model to use
+            messages (typing.List[llm_entities.Message]): list of message objects
+            funcs (typing.List[tools_entities.LLMFunction], optional): tool functions to use. Defaults to None.

-        Yields:
-            pkg.provider.entities.Message: the returned message object
+        Returns:
+            llm_entities.Message: the returned message object
         """
-        raise NotImplementedError
+        pass
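Note: the requester contract is now split: preprocess(query) may mutate the query for API-specific quirks, while call(model, messages, funcs) performs exactly one stateless round-trip; the tool-call loop lives in ChatMessageHandler.runner. A minimal subclass sketch under this contract (the requester name and the _post_to_backend helper are hypothetical):

    @requester_class("example-chat")
    class ExampleRequester(LLMAPIRequester):
        async def call(
            self,
            model: modelmgr_entities.LLMModelInfo,
            messages: typing.List[llm_entities.Message],
            funcs: typing.List[tools_entities.LLMFunction] = None,
        ) -> llm_entities.Message:
            # one request in, one Message out; no looping or yielding here
            payload = [m.dict(exclude_none=True) for m in messages]
            text = await self._post_to_backend(model, payload)  # hypothetical transport helper
            return llm_entities.Message(role='assistant', content=text)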
@@ -27,47 +27,60 @@ class AnthropicMessages(api.LLMAPIRequester):
             proxies=self.ap.proxy_mgr.get_forward_proxies()
         )

-    async def request(
+    async def call(
         self,
-        query: core_entities.Query,
-    ) -> typing.AsyncGenerator[llm_entities.Message, None]:
-        self.client.api_key = query.use_model.token_mgr.get_token()
+        model: entities.LLMModelInfo,
+        messages: typing.List[llm_entities.Message],
+        funcs: typing.List[tools_entities.LLMFunction] = None,
+    ) -> llm_entities.Message:
+        self.client.api_key = model.token_mgr.get_token()

         args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
-        args["model"] = query.use_model.name if query.use_model.model_name is None else query.use_model.model_name
+        args["model"] = model.name if model.model_name is None else model.model_name

-        req_messages = [  # req_messages is only used inside this class; external state is synced via query.messages
-            m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != ""
-        ] + [m.dict(exclude_none=True) for m in query.messages]
+        # process messages

-        # drop all messages with role=system and empty content
-        req_messages = [
-            m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
-        ]
+        # system
+        system_role_message = None

-        # if there are role=system messages, change them to role=user and insert a role=assistant message after each
-        system_role_index = []
-        for i, m in enumerate(req_messages):
-            if m["role"] == "system":
-                system_role_index.append(i)
-                m["role"] = "user"
+        for i, m in enumerate(messages):
+            if m.role == "system":
+                system_role_message = m

-        if system_role_index:
-            for i in system_role_index[::-1]:
-                req_messages.insert(i + 1, {"role": "assistant", "content": "Okay, I'll follow."})
+                messages.pop(i)
+                break

-        # ignore empty messages; the user may send empty messages that upper layers did not filter
-        req_messages = [
-            m for m in req_messages if m["content"].strip() != ""
-        ]
+        if isinstance(system_role_message, llm_entities.Message) \
+            and isinstance(system_role_message.content, str):
+            args['system'] = system_role_message.content

+        # other messages
+        # req_messages = [
+        #     m.dict(exclude_none=True) for m in messages \
+        #     if (isinstance(m.content, str) and m.content.strip() != "") \
+        #     or (isinstance(m.content, list) and )
+        # ]
+        # vision is not supported yet; keep only plain-text content
+        req_messages = []
+
+        for m in messages:
+            if isinstance(m.content, str) and m.content.strip() != "":
+                req_messages.append(m.dict(exclude_none=True))
+            elif isinstance(m.content, list):
+                # drop elements of m.content whose type != text
+                m.content = [
+                    c for c in m.content if c.type == "text"
+                ]
+
+                if len(m.content) > 0:
+                    req_messages.append(m.dict(exclude_none=True))

         args["messages"] = req_messages

         try:

             resp = await self.client.messages.create(**args)

-            yield llm_entities.Message(
+            return llm_entities.Message(
                 content=resp.content[0].text,
                 role=resp.role
             )
@@ -79,4 +92,4 @@ class AnthropicMessages(api.LLMAPIRequester):
             if 'model: ' in str(e):
                 raise errors.RequesterError(f'模型无效: {e.message}')
             else:
-                raise errors.RequesterError(f'请求地址无效: {e.message}')
+                raise errors.RequesterError(f'请求地址无效: {e.message}')
@@ -3,16 +3,20 @@ from __future__ import annotations
 import asyncio
 import typing
 import json
+import base64
-from typing import AsyncGenerator

 import openai
 import openai.types.chat.chat_completion as chat_completion
 import httpx
+import aiohttp
+import async_lru

 from .. import api, entities, errors
 from ....core import entities as core_entities, app
 from ... import entities as llm_entities
+from ...tools import entities as tools_entities
+from ....utils import image
@@ -43,7 +47,6 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
         self,
         args: dict,
     ) -> chat_completion.ChatCompletion:
-        self.ap.logger.debug(f"req chat_completion with args {args}")
         return await self.client.chat.completions.create(**args)

     async def _make_msg(
@@ -67,14 +70,22 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
         args = self.requester_cfg['args'].copy()
         args["model"] = use_model.name if use_model.model_name is None else use_model.model_name

-        if use_model.tool_call_supported:
+        if use_funcs:
             tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)

             if tools:
                 args["tools"] = tools

         # set the messages for this request
-        messages = req_messages
+        messages = req_messages.copy()
+
+        # check vision: inline image URLs as base64 data URIs
+        for msg in messages:
+            if 'content' in msg and isinstance(msg["content"], list):
+                for me in msg["content"]:
+                    if me["type"] == "image_url":
+                        me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url'])

         args["messages"] = messages

         # send the request
@@ -84,73 +95,19 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
         message = await self._make_msg(resp)

         return message

-    async def _request(
-        self, query: core_entities.Query
-    ) -> typing.AsyncGenerator[llm_entities.Message, None]:
-        """Request"""
-
-        pending_tool_calls = []
-
+    async def call(
+        self,
+        model: entities.LLMModelInfo,
+        messages: typing.List[llm_entities.Message],
+        funcs: typing.List[tools_entities.LLMFunction] = None,
+    ) -> llm_entities.Message:
         req_messages = [  # req_messages is only used inside this class; external state is synced via query.messages
-            m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != ""
-        ] + [m.dict(exclude_none=True) for m in query.messages]
+            m.dict(exclude_none=True) for m in messages
+        ]

-        # req_messages.append({"role": "user", "content": str(query.message_chain)})
-
-        # first request
-        msg = await self._closure(req_messages, query.use_model, query.use_funcs)
-
-        yield msg
-
-        pending_tool_calls = msg.tool_calls
-
-        req_messages.append(msg.dict(exclude_none=True))
-
-        # keep requesting as long as there are pending tool calls
-        while pending_tool_calls:
-            for tool_call in pending_tool_calls:
-                try:
-                    func = tool_call.function
-
-                    parameters = json.loads(func.arguments)
-
-                    func_ret = await self.ap.tool_mgr.execute_func_call(
-                        query, func.name, parameters
-                    )
-
-                    msg = llm_entities.Message(
-                        role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
-                    )
-
-                    yield msg
-
-                    req_messages.append(msg.dict(exclude_none=True))
-                except Exception as e:
-                    # on error, append an error message to req_messages
-                    err_msg = llm_entities.Message(
-                        role="tool", content=f"err: {e}", tool_call_id=tool_call.id
-                    )
-
-                    yield err_msg
-
-                    req_messages.append(
-                        err_msg.dict(exclude_none=True)
-                    )
-
-            # all calls handled, request again
-            msg = await self._closure(req_messages, query.use_model, query.use_funcs)
-
-            yield msg
-
-            pending_tool_calls = msg.tool_calls
-
-            req_messages.append(msg.dict(exclude_none=True))
-
-    async def request(self, query: core_entities.Query) -> AsyncGenerator[llm_entities.Message, None]:
         try:
-            async for msg in self._request(query):
-                yield msg
+            return await self._closure(req_messages, model, funcs)
         except asyncio.TimeoutError:
             raise errors.RequesterError('请求超时')
         except openai.BadRequestError as e:
@@ -163,6 +120,16 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
         except openai.NotFoundError as e:
             raise errors.RequesterError(f'请求路径错误: {e.message}')
         except openai.RateLimitError as e:
-            raise errors.RequesterError(f'请求过于频繁: {e.message}')
+            raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
         except openai.APIError as e:
             raise errors.RequesterError(f'请求错误: {e.message}')
+
+    @async_lru.alru_cache(maxsize=128)
+    async def get_base64_str(
+        self,
+        original_url: str,
+    ) -> str:
+
+        base64_image = await image.qq_image_url_to_base64(original_url)
+
+        return f"data:image/jpeg;base64,{base64_image}"
@@ -3,7 +3,10 @@ from __future__ import annotations
 from ....core import app

 from . import chatcmpl
-from .. import api
+from .. import api, entities, errors
+from ....core import entities as core_entities, app
+from ... import entities as llm_entities
+from ...tools import entities as tools_entities


 @api.requester_class("deepseek-chat-completions")
@@ -12,4 +15,39 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):

     def __init__(self, ap: app.Application):
         self.requester_cfg = ap.provider_cfg.data['requester']['deepseek-chat-completions']
-        self.ap = ap
+        self.ap = ap
+
+    async def _closure(
+        self,
+        req_messages: list[dict],
+        use_model: entities.LLMModelInfo,
+        use_funcs: list[tools_entities.LLMFunction] = None,
+    ) -> llm_entities.Message:
+        self.client.api_key = use_model.token_mgr.get_token()
+
+        args = self.requester_cfg['args'].copy()
+        args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
+
+        if use_funcs:
+            tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
+
+            if tools:
+                args["tools"] = tools
+
+        # set the messages for this request
+        messages = req_messages
+
+        # deepseek does not support multimodal input; convert all content to plain text
+        for m in messages:
+            if 'content' in m and isinstance(m["content"], list):
+                m["content"] = " ".join([c["text"] for c in m["content"]])
+
+        args["messages"] = messages
+
+        # send the request
+        resp = await self._req(args)
+
+        # handle the result
+        message = await self._make_msg(resp)
+
+        return message
@@ -3,7 +3,10 @@ from __future__ import annotations
 from ....core import app

 from . import chatcmpl
-from .. import api
+from .. import api, entities, errors
+from ....core import entities as core_entities, app
+from ... import entities as llm_entities
+from ...tools import entities as tools_entities


 @api.requester_class("moonshot-chat-completions")
@@ -13,3 +16,41 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
     def __init__(self, ap: app.Application):
         self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions']
         self.ap = ap
+
+    async def _closure(
+        self,
+        req_messages: list[dict],
+        use_model: entities.LLMModelInfo,
+        use_funcs: list[tools_entities.LLMFunction] = None,
+    ) -> llm_entities.Message:
+        self.client.api_key = use_model.token_mgr.get_token()
+
+        args = self.requester_cfg['args'].copy()
+        args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
+
+        if use_funcs:
+            tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
+
+            if tools:
+                args["tools"] = tools
+
+        # set the messages for this request
+        messages = req_messages
+
+        # moonshot does not support multimodal input; convert all content to plain text
+        for m in messages:
+            if 'content' in m and isinstance(m["content"], list):
+                m["content"] = " ".join([c["text"] for c in m["content"]])
+
+        # drop empty messages
+        messages = [m for m in messages if m["content"].strip() != ""]
+
+        args["messages"] = messages
+
+        # send the request
+        resp = await self._req(args)
+
+        # handle the result
+        message = await self._make_msg(resp)
+
+        return message
@@ -21,5 +21,7 @@ class LLMModelInfo(pydantic.BaseModel):
     tool_call_supported: typing.Optional[bool] = False

+    vision_supported: typing.Optional[bool] = False
+
     class Config:
         arbitrary_types_allowed = True
@@ -37,7 +37,7 @@ class ModelManager:
         raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")

     async def initialize(self):

         # initialize token_mgr and requester
         for k, v in self.ap.provider_cfg.data['keys'].items():
             self.token_mgrs[k] = token.TokenManager(k, v)
@@ -83,7 +83,8 @@ class ModelManager:
                     model_name=None,
                     token_mgr=self.token_mgrs[model['token_mgr']],
                     requester=self.requesters[model['requester']],
-                    tool_call_supported=model['tool_call_supported']
+                    tool_call_supported=model['tool_call_supported'],
+                    vision_supported=model['vision_supported']
                 )
                 break
@@ -95,13 +96,15 @@ class ModelManager:
             token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr
             requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
             tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported)
+            vision_supported = model.get('vision_supported', default_model_info.vision_supported)

             model_info = entities.LLMModelInfo(
                 name=model['name'],
                 model_name=model_name,
                 token_mgr=token_mgr,
                 requester=requester,
-                tool_call_supported=tool_call_supported
+                tool_call_supported=tool_call_supported,
+                vision_supported=vision_supported
             )
             self.model_list.append(model_info)
pkg/utils/image.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+import base64
+import typing
+from urllib.parse import urlparse, parse_qs
+import ssl
+
+import aiohttp
+
+
+async def qq_image_url_to_base64(
+    image_url: str
+) -> str:
+    """Convert a QQ image URL to base64
+
+    Args:
+        image_url (str): QQ image URL
+
+    Returns:
+        str: base64-encoded image
+    """
+    parsed = urlparse(image_url)
+    query = parse_qs(parsed.query)
+
+    # Flatten the query dictionary
+    query = {k: v[0] for k, v in query.items()}
+
+    ssl_context = ssl.create_default_context()
+    ssl_context.check_hostname = False
+    ssl_context.verify_mode = ssl.CERT_NONE
+
+    async with aiohttp.ClientSession(trust_env=False) as session:
+        async with session.get(
+            f"http://{parsed.netloc}{parsed.path}",
+            params=query,
+            ssl=ssl_context
+        ) as resp:
+            resp.raise_for_status()  # raise on HTTP errors
+            file_bytes = await resp.read()
+
+    base64_str = base64.b64encode(file_bytes).decode()
+
+    return base64_str
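Note: qq_image_url_to_base64 fetches the image over plain HTTP with certificate verification disabled (apparently to tolerate the QQ image servers' TLS setup) and returns the raw base64 payload; get_base64_str above wraps it into a data URI and memoizes it with alru_cache. A usage sketch (URL illustrative):

    b64 = await qq_image_url_to_base64("https://example-img-host/gchatpic_new/12345")
    data_uri = f"data:image/jpeg;base64,{b64}"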
@@ -13,4 +13,5 @@ aiohttp
 pydantic
 websockets
 urllib3
-psutil
+psutil
+async-lru
@@ -4,23 +4,73 @@
         "name": "default",
         "requester": "openai-chat-completions",
         "token_mgr": "openai",
-        "tool_call_supported": false
+        "tool_call_supported": false,
+        "vision_supported": false
     },
+    {
+        "name": "gpt-3.5-turbo-0125",
+        "tool_call_supported": true,
+        "vision_supported": false
+    },
     {
         "name": "gpt-3.5-turbo",
-        "tool_call_supported": true
+        "tool_call_supported": true,
+        "vision_supported": false
     },
     {
-        "name": "gpt-4",
-        "tool_call_supported": true
+        "name": "gpt-3.5-turbo-1106",
+        "tool_call_supported": true,
+        "vision_supported": false
     },
+    {
+        "name": "gpt-4-turbo",
+        "tool_call_supported": true,
+        "vision_supported": true
+    },
+    {
+        "name": "gpt-4-turbo-2024-04-09",
+        "tool_call_supported": true,
+        "vision_supported": true
+    },
     {
         "name": "gpt-4-turbo-preview",
-        "tool_call_supported": true
+        "tool_call_supported": true,
+        "vision_supported": true
     },
+    {
+        "name": "gpt-4-0125-preview",
+        "tool_call_supported": true,
+        "vision_supported": true
+    },
+    {
+        "name": "gpt-4-1106-preview",
+        "tool_call_supported": true,
+        "vision_supported": true
+    },
+    {
+        "name": "gpt-4",
+        "tool_call_supported": true,
+        "vision_supported": true
+    },
+    {
+        "name": "gpt-4o",
+        "tool_call_supported": true,
+        "vision_supported": true
+    },
+    {
+        "name": "gpt-4-0613",
+        "tool_call_supported": true,
+        "vision_supported": true
+    },
     {
         "name": "gpt-4-32k",
-        "tool_call_supported": true
+        "tool_call_supported": true,
+        "vision_supported": true
     },
+    {
+        "name": "gpt-4-32k-0613",
+        "tool_call_supported": true,
+        "vision_supported": true
+    },
     {
         "model_name": "SparkDesk",
@@ -1,5 +1,6 @@
 {
     "enable-chat": true,
+    "enable-vision": true,
     "keys": {
         "openai": [
             "sk-1234567890"